Example #1
# Assumed imports (stripped from this snippet): build_flags, misc, SmallSynthData,
# model_builder, train_utils, and train come from the surrounding dNRI-style project;
# the exact module paths depend on that project's layout.
if __name__ == '__main__':
    parser = build_flags()
    parser.add_argument('--data_path')
    parser.add_argument('--same_data_norm', action='store_true')
    parser.add_argument('--no_data_norm', action='store_true')
    parser.add_argument('--error_out_name',
                        default='prediction_errors_%dstep.npy')
    parser.add_argument('--prior_variance', type=float, default=5e-5)
    parser.add_argument('--test_burn_in_steps', type=int, default=10)
    parser.add_argument('--error_suffix')
    parser.add_argument('--subject_ind', type=int, default=-1)

    args = parser.parse_args()
    params = vars(args)

    misc.seed(args.seed)

    params['num_vars'] = 3
    params['input_size'] = 4
    params['input_time_steps'] = 50
    params['nll_loss_type'] = 'gaussian'
    train_data = SmallSynthData(args.data_path, 'train', params)
    val_data = SmallSynthData(args.data_path, 'val', params)

    model = model_builder.build_model(params)
    if args.mode == 'train':
        with train_utils.build_writers(args.working_dir) as (train_writer,
                                                             val_writer):
            train.train(model, train_data, val_data, params, train_writer,
                        val_writer)
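The entry point above extends a shared `build_flags()` parser with experiment-specific flags and then treats `vars(args)` as the `params` dictionary that the train functions below read with `params.get(...)`. Below is a minimal, self-contained sketch of that pattern; the `build_flags` helper and its flag set here are hypothetical stand-ins, not the project's actual definitions.

import argparse

def build_flags():
    # Hypothetical stand-in for the project's shared flag builder.
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', default='train')
    parser.add_argument('--working_dir', default='./out')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--lr', type=float, default=5e-4)
    parser.add_argument('--batch_size', type=int, default=1000)
    parser.add_argument('--gpu', action='store_true')
    return parser

if __name__ == '__main__':
    parser = build_flags()
    # Experiment-specific flags are layered on top, as in Example #1.
    parser.add_argument('--data_path')
    parser.add_argument('--prior_variance', type=float, default=5e-5)
    args = parser.parse_args()

    # vars(args) yields the plain dict that the train loops consume via params.get(...).
    params = vars(args)
    params['num_vars'] = 3  # fixed experiment settings are added after parsing
    print(params['lr'], params.get('clip_grad', None))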
Example #2
def train(model, train_data, val_data, params, train_writer, val_writer):
    gpu = params.get('gpu', False)
    batch_size = params.get('batch_size', 1000)
    val_batch_size = params.get('val_batch_size', batch_size)
    if val_batch_size is None:
        val_batch_size = batch_size
    accumulate_steps = params.get('accumulate_steps')
    # Both values are overwritten by train_utils.build_scheduler further down; the second
    # lookup repeats the 'training_scheduler' key, which is harmless for the same reason.
    training_scheduler = params.get('training_scheduler', None)
    q_training_scheduler = params.get('training_scheduler', None)
    num_epochs = params.get('num_epochs', 100)
    val_interval = params.get('val_interval', 1)
    val_start = params.get('val_start', 0)
    clip_grad = params.get('clip_grad', None)
    clip_grad_norm = params.get('clip_grad_norm', None)
    normalize_nll = params.get('normalize_nll', False)
    normalize_kl = params.get('normalize_kl', False)
    tune_on_nll = params.get('tune_on_nll', False)
    verbose = params.get('verbose', False)
    val_teacher_forcing = params.get('val_teacher_forcing', False)
    continue_training = params.get('continue_training', False)
    train_data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
    val_data_loader = DataLoader(val_data, batch_size=val_batch_size)
    lr = params['lr']
    wd = params.get('wd', 0.)
    mom = params.get('mom', 0.)
    
    # don't send q_net params to policy optimizer
    for p in model.decoder.q_net.parameters():
        p.requires_grad = False
    for p in model.Q_graph.parameters():
        p.requires_grad = False
    model_params = [param for param in model.parameters() if param.requires_grad]
    if params.get('use_adam', False):
        opt = torch.optim.Adam(model_params, lr=lr, weight_decay=wd)
        q_opt = torch.optim.Adam(list(model.decoder.q_net.parameters()) + list(model.Q_graph.parameters()), lr=lr, weight_decay=wd)
    else:
        opt = torch.optim.SGD(model_params, lr=lr, weight_decay=wd, momentum=mom)
        q_opt = torch.optim.SGD(list(model.decoder.q_net.parameters()) + list(model.Q_graph.parameters()), lr=lr, weight_decay=wd, momentum=mom)

    working_dir = params['working_dir']
    best_path = os.path.join(working_dir, 'best_model')
    checkpoint_dir = os.path.join(working_dir, 'model_checkpoint')
    training_path = os.path.join(working_dir, 'training_checkpoint')
    if continue_training:
        print("RESUMING TRAINING")
        model.load(checkpoint_dir)
        train_params = torch.load(training_path)
        start_epoch = train_params['epoch']
        opt.load_state_dict(train_params['optimizer'])
        q_opt.load_state_dict(train_params['q_optimizer'])
        best_val_result = train_params['best_val_result']
        best_val_epoch = train_params['best_val_epoch']
        print("STARTING EPOCH: ",start_epoch)
    else:
        start_epoch = 1
        best_val_epoch = -1
        best_val_result = 10000000
    
    training_scheduler = train_utils.build_scheduler(opt, params)
    q_training_scheduler = train_utils.build_scheduler(q_opt, params)
    end = start = 0 
    misc.seed(1)
    for epoch in range(start_epoch, num_epochs+1):
        print("EPOCH", epoch, (end-start))
        model.train()
        model.train_percent = epoch / num_epochs
        start = time.time() 
        for batch_ind, batch in enumerate(train_data_loader):
            inputs = batch['inputs']
            if gpu:
                inputs = inputs.cuda(non_blocking=True)
            
            # critic training
            for p in model.decoder.q_net.parameters():
                p.requires_grad = True
            for p in model.Q_graph.parameters():
                p.requires_grad = True
            q_opt.zero_grad()
            opt.zero_grad()
            # Single critic update per batch (the loop makes the update count easy to change).
            for _ in range(1):
                loss_critic, loss_nll = model.calculate_loss_q(inputs, is_train=True, return_logits=True)
                loss_critic.backward()
                if verbose:
                    print("critic loss: %f, nll: %f (epoch %d)" % (loss_critic.item(), loss_nll.mean().item(), epoch))
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                q_opt.step()
                q_opt.zero_grad()
                opt.zero_grad()
            for p in model.decoder.q_net.parameters():
                p.requires_grad = False
            for p in model.Q_graph.parameters():
                p.requires_grad = False

            # Finally, update Q_target networks by polyak averaging.
            # We only do it for the q_net, as we don't use the target policy
            with torch.no_grad():
                polyak = 0.9
                for p, p_targ in zip(model.decoder.q_net.parameters(), model.decoder_targ.q_net.parameters()):
                    # NB: we use the in-place operations "mul_" and "add_" to update target
                    # params, as opposed to "mul" and "add", which would allocate new tensors.
                    p_targ.data.mul_(polyak)
                    p_targ.data.add_((1 - polyak) * p.data)

                for p, p_targ in zip(model.Q_graph.parameters(), model.Q_graph_targ.parameters()):
                    # Same in-place polyak update for the graph critic's target parameters.
                    p_targ.data.mul_(polyak)
                    p_targ.data.add_((1 - polyak) * p.data)

            # policy training
            q_opt.zero_grad()
            opt.zero_grad()
            # Two policy updates per critic update.
            for _ in range(2):
                loss, loss_policy, loss_kl, logits, _ = model.calculate_loss_pi(inputs, is_train=True, return_logits=True)
                loss.backward()
                if verbose:
                    print("policy loss: %f" % loss.item())
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()
                opt.zero_grad()
                q_opt.zero_grad()

            
        if training_scheduler is not None:
            training_scheduler.step()
            q_training_scheduler.step()
        
        if train_writer is not None:
            train_writer.add_scalar('loss', loss.item(), global_step=epoch)
            if normalize_nll:
                train_writer.add_scalar('NLL', loss_nll.mean().item(), global_step=epoch)
            else:
                train_writer.add_scalar('NLL', loss_nll.mean().item()/(inputs.size(1)*inputs.size(2)), global_step=epoch)
            
            train_writer.add_scalar("KL Divergence", loss_kl.mean().item(), global_step=epoch)
        model.eval()
        opt.zero_grad()

        total_nll = 0
        total_kl = 0
        if verbose:
            print("COMPUTING VAL LOSSES")
        with torch.no_grad():
            for batch_ind, batch in enumerate(val_data_loader):
                inputs = batch['inputs']
                if gpu:
                    inputs = inputs.cuda(non_blocking=True)
                loss_critic, loss_nll = model.calculate_loss_q(inputs, is_train=False, teacher_forcing=val_teacher_forcing, return_logits=True)
                loss, loss_policy, loss_kl, logits, _ = model.calculate_loss_pi(inputs, is_train=False, teacher_forcing=val_teacher_forcing, return_logits=True)
                total_kl += loss_kl.sum().item()
                total_nll += loss_nll.sum().item()
                if verbose:
                    print("\tVAL BATCH %d of %d: %f, %f"%(batch_ind+1, len(val_data_loader), loss_nll.mean(), loss_kl.mean()))
            
        total_kl /= len(val_data)
        total_nll /= len(val_data)
        total_loss = model.kl_coef * total_kl + total_nll  # weight the KL term by the model's KL coefficient
        if val_writer is not None:
            val_writer.add_scalar('loss', total_loss, global_step=epoch)
            val_writer.add_scalar("NLL", total_nll, global_step=epoch)
            val_writer.add_scalar("KL Divergence", total_kl, global_step=epoch)
        if tune_on_nll:
            tuning_loss = total_nll
        else:
            tuning_loss = total_loss
        if tuning_loss < best_val_result:
            best_val_epoch = epoch
            best_val_result = tuning_loss
            print("BEST VAL RESULT. SAVING MODEL...")
            model.save(best_path)
        model.save(checkpoint_dir)
        torch.save({
                    'epoch':epoch+1,
                    'optimizer':opt.state_dict(),
                    'q_optimizer':q_opt.state_dict(),
                    'best_val_result':best_val_result,
                    'best_val_epoch':best_val_epoch,
                   }, training_path)
        print("EPOCH %d EVAL: "%epoch)
        print("\tCURRENT VAL LOSS: %f"%tuning_loss)
        print("\tBEST VAL LOSS:    %f"%best_val_result)
        print("\tBEST VAL EPOCH:   %d"%best_val_epoch)
        end = time.time()
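The critic update in Example #2 maintains target networks by Polyak averaging: after each critic step, every target parameter is blended toward its online counterpart in place. Here is a minimal, self-contained sketch of that update, using a throwaway linear layer in place of the model's q_net:

import copy
import torch
import torch.nn as nn

def polyak_update(online, target, polyak=0.9):
    # Blend target params toward online params in place: targ <- polyak * targ + (1 - polyak) * online.
    with torch.no_grad():
        for p, p_targ in zip(online.parameters(), target.parameters()):
            p_targ.data.mul_(polyak)
            p_targ.data.add_((1 - polyak) * p.data)

q_net = nn.Linear(4, 1)
q_net_targ = copy.deepcopy(q_net)    # target starts as an exact copy
for p in q_net_targ.parameters():    # the target is never trained directly
    p.requires_grad = False

# ... after an optimizer step on q_net ...
polyak_update(q_net, q_net_targ, polyak=0.9)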
Example #3
def train(model, train_data, val_data, params, train_writer, val_writer):
    gpu = params.get('gpu', False)
    batch_size = params.get('batch_size', 1000)
    sub_batch_size = params.get('sub_batch_size')
    if sub_batch_size is None:
        sub_batch_size = batch_size
    val_batch_size = params.get('val_batch_size', batch_size)
    if val_batch_size is None:
        val_batch_size = batch_size
    accumulate_steps = params.get('accumulate_steps', 1)
    training_scheduler = params.get('training_scheduler', None)
    num_epochs = params.get('num_epochs', 100)
    val_interval = params.get('val_interval', 1)
    val_start = params.get('val_start', 0)
    clip_grad = params.get('clip_grad', None)
    clip_grad_norm = params.get('clip_grad_norm', None)
    normalize_nll = params.get('normalize_nll', False)
    normalize_kl = params.get('normalize_kl', False)
    tune_on_nll = params.get('tune_on_nll', False)
    verbose = params.get('verbose', False)
    val_teacher_forcing = params.get('val_teacher_forcing', False)
    collate_fn = params.get('collate_fn', None)
    continue_training = params.get('continue_training', False)
    normalize_inputs = params['normalize_inputs']
    num_decoder_samples = 1
    train_data_loader = DataLoader(train_data,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   drop_last=True,
                                   collate_fn=collate_fn)
    print("NUM BATCHES: ", len(train_data_loader))
    val_data_loader = DataLoader(val_data,
                                 batch_size=val_batch_size,
                                 collate_fn=collate_fn)
    lr = params['lr']
    wd = params.get('wd', 0.)
    mom = params.get('mom', 0.)

    model_params = [
        param for param in model.parameters() if param.requires_grad
    ]
    if params.get('use_adam', False):
        opt = torch.optim.Adam(model_params, lr=lr, weight_decay=wd)
    else:
        opt = torch.optim.SGD(model_params,
                              lr=lr,
                              weight_decay=wd,
                              momentum=mom)

    working_dir = params['working_dir']
    best_path = os.path.join(working_dir, 'best_model')
    checkpoint_dir = os.path.join(working_dir, 'model_checkpoint')
    training_path = os.path.join(working_dir, 'training_checkpoint')
    if continue_training:
        print("RESUMING TRAINING")
        model.load(checkpoint_dir)
        train_params = torch.load(training_path)
        start_epoch = train_params['epoch']
        opt.load_state_dict(train_params['optimizer'])
        best_val_result = train_params['best_val_result']
        best_val_epoch = train_params['best_val_epoch']
        model.steps = train_params['step']
        print("STARTING EPOCH: ", start_epoch)
    else:
        start_epoch = 1
        best_val_epoch = -1
        best_val_result = 10000000

    training_scheduler = train_utils.build_scheduler(opt, params)
    end = start = 0
    misc.seed(1)
    for epoch in range(start_epoch, num_epochs + 1):
        model.epoch = epoch
        print("EPOCH", epoch, (end - start))

        model.train_percent = epoch / num_epochs
        start = time.time()
        for batch_ind, batch in enumerate(train_data_loader):
            model.train()
            inputs = batch['inputs']
            masks = batch.get('masks', None)
            node_inds = batch.get('node_inds', None)
            graph_info = batch.get('graph_info', None)
            if gpu:
                inputs = inputs.cuda(non_blocking=True)
                if masks is not None:
                    masks = masks.cuda(non_blocking=True)
            args = {'is_train': True, 'return_logits': True}
            sub_steps = len(range(0, batch_size, sub_batch_size))
            for sub_batch_ind in range(0, batch_size, sub_batch_size):
                sub_inputs = inputs[sub_batch_ind:sub_batch_ind +
                                    sub_batch_size]
                for sample in range(num_decoder_samples):
                    if normalize_inputs:
                        if masks is not None:
                            normalized_inputs = model.normalize_inputs(
                                inputs[:, :-1], masks[:, :-1])
                        else:
                            normalized_inputs = model.normalize_inputs(
                                inputs[:, :-1])
                        args['normalized_inputs'] = normalized_inputs[
                            sub_batch_ind:sub_batch_ind + sub_batch_size]
                    if masks is not None:
                        sub_masks = masks[sub_batch_ind:sub_batch_ind +
                                          sub_batch_size]
                        sub_node_inds = node_inds[sub_batch_ind:sub_batch_ind +
                                                  sub_batch_size]
                        sub_graph_info = graph_info[
                            sub_batch_ind:sub_batch_ind + sub_batch_size]
                        loss, loss_nll, loss_kl, logits, _ = model.calculate_loss(
                            sub_inputs, sub_masks, sub_node_inds,
                            sub_graph_info, **args)
                    else:
                        loss, loss_nll, loss_kl, logits, _ = model.calculate_loss(
                            sub_inputs, **args)
                    loss = loss / (sub_steps * accumulate_steps *
                                   num_decoder_samples)
                    loss.backward()

                if verbose:
                    # sub_batch_ind is a start offset in the batch, so divide by sub_batch_size
                    # to recover the sub-step index when counting batches.
                    tmp_batch_ind = batch_ind * sub_steps + sub_batch_ind // sub_batch_size + 1
                    tmp_total_batch = len(train_data_loader) * sub_steps
                    print("\tBATCH %d OF %d: %f, %f, %f" %
                          (tmp_batch_ind, tmp_total_batch, loss.item(),
                           loss_nll.mean().item(), loss_kl.mean().item()))
            if accumulate_steps == -1 or (batch_ind +
                                          1) % accumulate_steps == 0:
                if verbose and accumulate_steps > 0:
                    print("\tUPDATING WEIGHTS")
                if clip_grad is not None:
                    nn.utils.clip_grad_value_(model.parameters(), clip_grad)
                elif clip_grad_norm is not None:
                    nn.utils.clip_grad_norm_(model.parameters(),
                                             clip_grad_norm)
                opt.step()
                model.steps += 1
                opt.zero_grad()
                if accumulate_steps > 0 and accumulate_steps > len(
                        train_data_loader) - batch_ind - 1:
                    break

        if training_scheduler is not None:
            training_scheduler.step()

        if train_writer is not None:
            train_writer.add_scalar(
                'loss',
                loss.item() *
                (sub_steps * accumulate_steps * num_decoder_samples),
                global_step=epoch)
            if normalize_nll:
                train_writer.add_scalar('NLL',
                                        loss_nll.mean().item(),
                                        global_step=epoch)
            else:
                train_writer.add_scalar('NLL',
                                        loss_nll.mean().item() /
                                        (inputs.size(1) * inputs.size(2)),
                                        global_step=epoch)

            train_writer.add_scalar("KL Divergence",
                                    loss_kl.mean().item(),
                                    global_step=epoch)
        if ((epoch + 1) % val_interval != 0):
            end = time.time()
            continue
        model.eval()
        opt.zero_grad()
        total_nll = 0
        total_kl = 0
        if verbose:
            print("COMPUTING VAL LOSSES")
        with torch.no_grad():
            for batch_ind, batch in enumerate(val_data_loader):
                inputs = batch['inputs']
                masks = batch.get('masks', None)
                node_inds = batch.get('node_inds', None)
                graph_info = batch.get('graph_info', None)
                if gpu:
                    inputs = inputs.cuda(non_blocking=True)
                    if masks is not None:
                        masks = masks.cuda(non_blocking=True)
                if masks is not None:
                    loss, loss_nll, loss_kl, logits, _ = model.calculate_loss(
                        inputs,
                        masks,
                        node_inds,
                        graph_info,
                        is_train=False,
                        teacher_forcing=val_teacher_forcing,
                        return_logits=True)
                else:
                    loss, loss_nll, loss_kl, logits, _ = model.calculate_loss(
                        inputs,
                        is_train=False,
                        teacher_forcing=val_teacher_forcing,
                        return_logits=True)
                total_kl += loss_kl.sum().item()
                total_nll += loss_nll.sum().item()
                if verbose:
                    print("\tVAL BATCH %d of %d: %f, %f" %
                          (batch_ind + 1, len(val_data_loader), loss_nll.mean(),
                           loss_kl.mean()))

        total_kl /= len(val_data)
        total_nll /= len(val_data)
        total_loss = model.kl_coef * total_kl + total_nll  # weight the KL term by the model's KL coefficient
        if val_writer is not None:
            val_writer.add_scalar('loss', total_loss, global_step=epoch)
            val_writer.add_scalar("NLL", total_nll, global_step=epoch)
            val_writer.add_scalar("KL Divergence", total_kl, global_step=epoch)

        if tune_on_nll:
            tuning_loss = total_nll
        else:
            tuning_loss = total_loss
        if tuning_loss < best_val_result:
            best_val_epoch = epoch
            best_val_result = tuning_loss
            print("BEST VAL RESULT. SAVING MODEL...")
            model.save(best_path)
        model.save(checkpoint_dir)
        torch.save(
            {
                'epoch': epoch + 1,
                'optimizer': opt.state_dict(),
                'best_val_result': best_val_result,
                'best_val_epoch': best_val_epoch,
                'step': model.steps,
            }, training_path)
        print("EPOCH %d EVAL: " % epoch)
        print("\tCURRENT VAL LOSS: %f" % tuning_loss)
        print("\tBEST VAL LOSS:    %f" % best_val_result)
        print("\tBEST VAL EPOCH:   %d" % best_val_epoch)

        end = time.time()
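Example #3 splits each loaded batch into sub-batches and divides the loss by `sub_steps * accumulate_steps * num_decoder_samples`, so the gradients summed across all `backward()` calls average out to one effective batch before `opt.step()` runs. Here is a minimal sketch of that accumulation pattern on a toy regression problem; the model, data, and sizes are made up for illustration:

import torch
import torch.nn as nn

model = nn.Linear(8, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()

batch_size, sub_batch_size, accumulate_steps = 32, 8, 2
sub_steps = len(range(0, batch_size, sub_batch_size))

for batch_ind in range(4):  # stand-in for the data loader
    inputs = torch.randn(batch_size, 8)
    targets = torch.randn(batch_size, 1)
    for start in range(0, batch_size, sub_batch_size):
        sub_in = inputs[start:start + sub_batch_size]
        sub_tg = targets[start:start + sub_batch_size]
        loss = loss_fn(model(sub_in), sub_tg)
        # Scale so the summed gradients match one averaged update over the accumulated batches.
        (loss / (sub_steps * accumulate_steps)).backward()
    if (batch_ind + 1) % accumulate_steps == 0:
        opt.step()
        opt.zero_grad()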
Example #4
File: train.py  Project: Aks-Dmv/idnri
def train(model, train_data, val_data, params, train_writer, val_writer):
    gpu = params.get('gpu', False)
    batch_size = params.get('batch_size', 1000)
    val_batch_size = params.get('val_batch_size', batch_size)
    if val_batch_size is None:
        val_batch_size = batch_size
    accumulate_steps = params.get('accumulate_steps', 1)  # default of 1: this value is used in a modulo below
    training_scheduler = params.get('training_scheduler', None)
    num_epochs = params.get('num_epochs', 100)
    val_interval = params.get('val_interval', 1)
    val_start = params.get('val_start', 0)
    clip_grad = params.get('clip_grad', None)
    clip_grad_norm = params.get('clip_grad_norm', None)
    normalize_nll = params.get('normalize_nll', False)
    normalize_kl = params.get('normalize_kl', False)
    tune_on_nll = params.get('tune_on_nll', False)
    verbose = params.get('verbose', False)
    val_teacher_forcing = params.get('val_teacher_forcing', False)
    continue_training = params.get('continue_training', False)
    train_data_loader = DataLoader(train_data,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   drop_last=True)
    val_data_loader = DataLoader(val_data, batch_size=val_batch_size)
    lr = params['lr']
    wd = params.get('wd', 0.)
    mom = params.get('mom', 0.)

    # DNRI_Disc is the project's discriminator; the unconditional .cuda() assumes a CUDA device is available.
    disc = dnriModels.DNRI_Disc(params).cuda()
    disc_params = [param for param in disc.parameters() if param.requires_grad]
    model_params = [
        param for param in model.parameters() if param.requires_grad
    ]
    if params.get('use_adam', False):
        disc_opt = torch.optim.Adam(disc_params, lr=lr, weight_decay=wd)
        opt = torch.optim.Adam(model_params, lr=lr, weight_decay=wd)
    else:
        disc_opt = torch.optim.SGD(disc_params,
                                   lr=lr,
                                   weight_decay=wd,
                                   momentum=mom)
        opt = torch.optim.SGD(model_params,
                              lr=lr,
                              weight_decay=wd,
                              momentum=mom)

    working_dir = params['working_dir']
    best_path = os.path.join(working_dir, 'best_model')
    checkpoint_dir = os.path.join(working_dir, 'model_checkpoint')
    training_path = os.path.join(working_dir, 'training_checkpoint')
    if continue_training:
        print("RESUMING TRAINING")
        model.load(checkpoint_dir)
        train_params = torch.load(training_path)
        start_epoch = train_params['epoch']
        opt.load_state_dict(train_params['optimizer'])
        best_val_result = train_params['best_val_result']
        best_val_epoch = train_params['best_val_epoch']
        print("STARTING EPOCH: ", start_epoch)
    else:
        start_epoch = 1
        best_val_epoch = -1
        best_val_result = 10000000

    disc_training_scheduler = train_utils.build_scheduler(disc_opt, params)
    training_scheduler = train_utils.build_scheduler(opt, params)
    end = start = 0
    misc.seed(1)
    CEloss = nn.CrossEntropyLoss()
    for epoch in range(start_epoch, num_epochs + 1):
        #print("EPOCH", epoch, (end-start))
        disc.train()
        disc.train_percent = epoch / num_epochs
        model.train()
        model.train_percent = epoch / num_epochs
        start = time.time()
        for batch_ind, batch in enumerate(train_data_loader):
            inputs = batch['inputs']
            if gpu:
                inputs = inputs.cuda(non_blocking=True)

            loss, loss_nll, loss_kl, logits, all_Preds = model.calculate_loss(
                inputs, is_train=True, return_logits=True, disc=disc)
            x1_x2_pairs = torch.cat(
                [all_Preds[:, :-1, :, :], all_Preds[:, 1:, :, :]],
                dim=-1).detach().clone().cuda()
            discrim_pred = disc(x1_x2_pairs)
            discrim_prob = nn.functional.softmax(discrim_pred, dim=-1)
            disc_logits = logits[:, :-1, :, :].argmax(dim=-1)

            if batch_ind == 0 and (epoch % 99) == 0:
                print("disc_prob/encoder_logits", discrim_prob.shape,
                      disc_logits.shape)
                for i in range(logits.shape[1] - 1):
                    print(
                        "prob/trgt", discrim_prob[1,
                                                  i, :].cpu().detach().argmax(
                                                      dim=-1).numpy(),
                        disc_logits[1, i, :].cpu().detach().numpy())

            discrim_prob = discrim_prob.view(-1, logits.shape[-1])
            disc_logits = disc_logits.flatten().detach().clone().long()

            valid_idx = disc_logits.nonzero().view(-1)
            valid_idx0 = torch.randint(discrim_prob.shape[0],
                                       (valid_idx.shape[0], ))
            final_disc_prob = torch.cat(
                [discrim_prob[valid_idx], discrim_prob[valid_idx0]], dim=0)
            final_disc_logits = torch.cat(
                [disc_logits[valid_idx], disc_logits[valid_idx0]], dim=0)
            # Note: nn.CrossEntropyLoss expects unnormalized scores and applies log-softmax internally,
            # so feeding the softmax probabilities computed above applies softmax twice inside the loss.
            disc_loss = CEloss(final_disc_prob, final_disc_logits)
            disc_loss.backward()
            disc_opt.step()
            disc_opt.zero_grad()

            loss.backward()
            if verbose:
                print("\tBATCH %d OF %d: %f, %f, %f" %
                      (batch_ind + 1, len(train_data_loader), loss.item(),
                       loss_nll.mean().item(), loss_kl.mean().item()))
            if accumulate_steps == -1 or (batch_ind +
                                          1) % accumulate_steps == 0:
                if verbose and accumulate_steps > 0:
                    print("\tUPDATING WEIGHTS")
                if clip_grad is not None:
                    nn.utils.clip_grad_value_(model.parameters(), clip_grad)
                elif clip_grad_norm is not None:
                    nn.utils.clip_grad_norm_(model.parameters(),
                                             clip_grad_norm)
                opt.step()
                opt.zero_grad()
                if accumulate_steps > 0 and accumulate_steps > len(
                        train_data_loader) - batch_ind - 1:
                    break

        if training_scheduler is not None:
            training_scheduler.step()
            disc_training_scheduler.step()

        if train_writer is not None:
            train_writer.add_scalar('loss', loss.item(), global_step=epoch)
            if normalize_nll:
                train_writer.add_scalar('NLL',
                                        loss_nll.mean().item(),
                                        global_step=epoch)
            else:
                train_writer.add_scalar('NLL',
                                        loss_nll.mean().item() /
                                        (inputs.size(1) * inputs.size(2)),
                                        global_step=epoch)

            train_writer.add_scalar("KL Divergence",
                                    loss_kl.mean().item(),
                                    global_step=epoch)
        model.eval()
        opt.zero_grad()
        disc_opt.zero_grad()

        total_nll = 0
        total_kl = 0
        if verbose:
            print("COMPUTING VAL LOSSES")
        with torch.no_grad():
            for batch_ind, batch in enumerate(val_data_loader):
                inputs = batch['inputs']
                if gpu:
                    inputs = inputs.cuda(non_blocking=True)
                loss, loss_nll, loss_kl, logits, _ = model.calculate_loss(
                    inputs,
                    is_train=False,
                    teacher_forcing=val_teacher_forcing,
                    return_logits=True)
                total_kl += loss_kl.sum().item()
                total_nll += loss_nll.sum().item()
                if verbose:
                    print("\tVAL BATCH %d of %d: %f, %f" %
                          (batch_ind + 1, len(val_data_loader),
                           loss_nll.mean(), loss_kl.mean()))

        total_kl /= len(val_data)
        total_nll /= len(val_data)
        total_loss = model.kl_coef * total_kl + total_nll  # weight the KL term by the model's KL coefficient
        if val_writer is not None:
            val_writer.add_scalar('loss', total_loss, global_step=epoch)
            val_writer.add_scalar("NLL", total_nll, global_step=epoch)
            val_writer.add_scalar("KL Divergence", total_kl, global_step=epoch)
        if tune_on_nll:
            tuning_loss = total_nll
        else:
            tuning_loss = total_loss
        if tuning_loss < best_val_result:
            best_val_epoch = epoch
            best_val_result = tuning_loss
            #print("BEST VAL RESULT. SAVING MODEL...")
            model.save(best_path)
        model.save(checkpoint_dir)
        torch.save(
            {
                'epoch': epoch + 1,
                'optimizer': opt.state_dict(),
                'best_val_result': best_val_result,
                'best_val_epoch': best_val_epoch,
            }, training_path)
        #print("EPOCH %d EVAL: "%epoch)
        #print("\tCURRENT VAL LOSS: %f"%tuning_loss)
        #print("\tBEST VAL LOSS:    %f"%best_val_result)
        #print("\tBEST VAL EPOCH:   %d"%best_val_epoch)
        end = time.time()
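Because `nn.CrossEntropyLoss` applies log-softmax internally, it expects raw class scores rather than softmax probabilities. A short sketch with made-up tensors comparing the intended usage with the double-softmax variant that the discriminator loss above effectively performs:

import torch
import torch.nn as nn

ce = nn.CrossEntropyLoss()
logits = torch.randn(6, 4)             # raw, unnormalized class scores
targets = torch.randint(0, 4, (6,))

loss_from_logits = ce(logits, targets)                  # intended usage
loss_from_probs = ce(logits.softmax(dim=-1), targets)   # softmax applied twice inside the loss

print(loss_from_logits.item(), loss_from_probs.item())  # the two values differ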