Example #1
0
def load_model_from_file(save_folder, input_shape, device):
    """Rebuild a trained ContrastiveSWM from a save folder.

    Reads the pickled training arguments from ``metadata.pkl`` and the
    weights from ``model.pt`` inside *save_folder*.

    Args:
        save_folder: Directory containing ``metadata.pkl`` and ``model.pt``.
        input_shape: Observation shape forwarded to the model constructor.
        device: ``torch.device`` the model is moved to.

    Returns:
        Tuple ``(model, train_args)``: the restored model on *device* and
        the argument namespace it was trained with.
    """
    meta_file = os.path.join(save_folder, 'metadata.pkl')
    model_file = os.path.join(save_folder, 'model.pt')
    # Context manager closes the metadata file promptly instead of leaking
    # the handle until garbage collection.
    with open(meta_file, 'rb') as f:
        train_args = pickle.load(f)['args']
    model = modules.ContrastiveSWM(embedding_dim=train_args.embedding_dim,
                                   hidden_dim=train_args.hidden_dim,
                                   action_dim=train_args.action_dim,
                                   input_dims=input_shape,
                                   num_objects=train_args.num_objects,
                                   sigma=train_args.sigma,
                                   hinge=train_args.hinge,
                                   ignore_action=train_args.ignore_action,
                                   copy_action=train_args.copy_action,
                                   encoder=train_args.encoder).to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    model.load_state_dict(torch.load(model_file, map_location=device))
    return model, train_args
Example #2
0
def evaluate(args, args_eval, model_file):
    """Rank the true successor state against all evaluated successors.

    Encodes the first observation of every path, rolls the transition
    model forward ``args_eval.num_steps`` steps, and ranks the true final
    state among all final states by pairwise latent distance.

    Args:
        args: Training-time arguments (hyper-parameters, batch size, seed).
        args_eval: Evaluation arguments (dataset path, num_steps, dedup,
            optional sacred logging).
        model_file: Path to the saved model ``state_dict``.

    Returns:
        Tuple ``(hits, mrr)``: Hits@1 and mean reciprocal rank over the
        evaluation set.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # Optionally mirror the run configuration and metrics to a sacred/Mongo
    # observer; imports stay local so sacred is only required when enabled.
    ex = None
    if args_eval.sacred:
        from sacred import Experiment
        from sacred.observers import MongoObserver
        ex = Experiment(args_eval.sacred_name)
        ex.observers.append(
            MongoObserver(url=constants.MONGO_URI, db_name=constants.DB_NAME))
        ex.add_config({
            "batch_size": args.batch_size,
            "epochs": args.epochs,
            "learning_rate": args.learning_rate,
            "encoder": args.encoder,
            "num_objects": args.num_objects,
            "custom_neg": args.custom_neg,
            "in_ep_prob": args.in_ep_prob,
            "seed": args.seed,
            "dataset": args.dataset,
            "save_folder": args.save_folder,
            "eval_dataset": args_eval.dataset,
            "num_steps": args_eval.num_steps,
            "use_action_attention": args.use_action_attention
        })

    device = torch.device('cuda' if args.cuda else 'cpu')

    dataset = utils.PathDatasetStateIds(hdf5_file=args.dataset,
                                        path_length=args_eval.num_steps)
    eval_loader = data.DataLoader(dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=4)

    # Get a data sample to infer the observation shape.  next(iter(...))
    # replaces the Python-2-only ``.next()`` call, which does not exist on
    # Python-3 iterators (and was removed from newer DataLoader iterators).
    obs = next(iter(eval_loader))[0]
    input_shape = obs[0][0].size()

    model = modules.ContrastiveSWM(
        embedding_dim=args.embedding_dim,
        hidden_dim=args.hidden_dim,
        action_dim=args.action_dim,
        input_dims=input_shape,
        num_objects=args.num_objects,
        sigma=args.sigma,
        hinge=args.hinge,
        ignore_action=args.ignore_action,
        copy_action=args.copy_action,
        split_mlp=args.split_mlp,
        same_ep_neg=args.same_ep_neg,
        only_same_ep_neg=args.only_same_ep_neg,
        immovable_bit=args.immovable_bit,
        split_gnn=args.split_gnn,
        no_loss_first_two=args.no_loss_first_two,
        bisim_model=make_pairwise_encoder() if args.bisim_model_path else None,
        encoder=args.encoder,
        use_action_attention=args.use_action_attention).to(device)

    model.load_state_dict(torch.load(model_file))
    model.eval()

    # Only Hits@1 is reported; extend to e.g. [1, 5, 10] for more cutoffs.
    topk = [1]
    hits_at = defaultdict(int)
    num_samples = 0
    rr_sum = 0

    pred_states = []
    next_states = []
    next_ids = []

    with torch.no_grad():

        for batch_idx, data_batch in enumerate(eval_loader):
            data_batch = [[t.to(device) for t in tensor]
                          for tensor in data_batch]
            observations, actions, state_ids = data_batch

            # Drop the final partial batch so all concatenated chunks
            # have a uniform size.
            if observations[0].size(0) != args.batch_size:
                continue

            obs = observations[0]
            next_obs = observations[-1]
            next_id = state_ids[-1]

            state = model.obj_encoder(model.obj_extractor(obs))
            next_state = model.obj_encoder(model.obj_extractor(next_obs))

            # Roll the latent state forward one action at a time.
            pred_state = state
            for i in range(args_eval.num_steps):
                pred_state = model.forward_transition(pred_state, actions[i])

            pred_states.append(pred_state.cpu())
            next_states.append(next_state.cpu())
            next_ids.append(next_id.cpu().numpy())

        pred_state_cat = torch.cat(pred_states, dim=0)
        next_state_cat = torch.cat(next_states, dim=0)
        next_ids_cat = np.concatenate(next_ids, axis=0)

        full_size = pred_state_cat.size(0)

        # Flatten object/feature dimensions
        next_state_flat = next_state_cat.view(full_size, -1)
        pred_state_flat = pred_state_cat.view(full_size, -1)

        dist_matrix = utils.pairwise_distance_matrix(next_state_flat,
                                                     pred_state_flat)

        # Prepend the diagonal (distance to the true successor) as column 0
        # so that rank 0 after sorting means a correct prediction.
        dist_matrix_diag = torch.diag(dist_matrix).unsqueeze(-1)
        dist_matrix_augmented = torch.cat([dist_matrix_diag, dist_matrix],
                                          dim=1)

        # Workaround to get a stable sort in numpy: lexsort with the
        # column index as tie-breaker keeps equal distances in order.
        dist_np = dist_matrix_augmented.numpy()
        indices = []
        for row in dist_np:
            keys = (np.arange(len(row)), row)
            indices.append(np.lexsort(keys))
        indices = np.stack(indices, axis=0)

        if args_eval.dedup:
            # A "mistake" whose top-ranked state has the same underlying
            # state id as the true successor is counted as correct.
            mask_mistakes = indices[:, 0] != 0
            closest_next_ids = next_ids_cat[indices[:, 0] - 1]

            if len(next_ids_cat.shape) == 2:
                equal_mask = np.all(closest_next_ids == next_ids_cat, axis=1)
            else:
                equal_mask = closest_next_ids == next_ids_cat

            indices[:, 0][np.logical_and(equal_mask, mask_mistakes)] = 0

        indices = torch.from_numpy(indices).long()

        labels = torch.zeros(indices.size(0),
                             device=indices.device,
                             dtype=torch.int64).unsqueeze(-1)

        num_samples += full_size

        for k in topk:
            match = indices[:, :k] == labels
            num_matches = match.sum()
            hits_at[k] += num_matches.item()

        match = indices == labels
        _, ranks = match.max(1)

        reciprocal_ranks = torch.reciprocal(ranks.double() + 1)
        rr_sum += reciprocal_ranks.sum().item()

        pred_states = []
        next_states = []
        next_ids = []

    hits = hits_at[topk[0]] / float(num_samples)
    mrr = rr_sum / float(num_samples)

    if ex is not None:
        # ugly hack: sacred requires a main function to log scalars.
        @ex.main
        def sacred_main():
            ex.log_scalar("hits", hits)
            ex.log_scalar("mrr", mrr)

        ex.run()

    print('Hits @ {}: {}'.format(topk[0], hits))
    print('MRR: {}'.format(mrr))

    return hits, mrr
Example #3
0
# Build the training data pipeline.
dataset = utils.StateTransitionsDataset(hdf5_file=args.dataset)
train_loader = data.DataLoader(dataset,
                               batch_size=args.batch_size,
                               shuffle=True,
                               num_workers=4)

# Get a data sample to infer the observation shape.  next(iter(...))
# replaces the Python-2-only ``.next()`` call, which does not exist on
# Python-3 iterators (and was removed from newer DataLoader iterators).
obs = next(iter(train_loader))[0]
input_shape = obs[0].size()

model = modules.ContrastiveSWM(embedding_dim=args.embedding_dim,
                               hidden_dim=args.hidden_dim,
                               action_dim=args.action_dim,
                               input_dims=input_shape,
                               num_objects=args.num_objects,
                               sigma=args.sigma,
                               hinge=args.hinge,
                               ignore_action=args.ignore_action,
                               copy_action=args.copy_action,
                               encoder=args.encoder).to(device)

model.apply(utils.weights_init)

optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

if args.decoder:
    # NOTE(review): only the 'large' encoder branch is visible here;
    # presumably 'medium'/'small' branches follow — confirm in full file.
    if args.encoder == 'large':
        decoder = modules.DecoderCNNLarge(input_dim=args.embedding_dim,
                                          num_objects=args.num_objects,
                                          hidden_dim=args.hidden_dim // 16,
                                          output_size=input_shape).to(device)
Example #4
0
def evaluate(args, args_eval, model_file):
    """Path-based evaluation: find the predicted step along a 10-step path.

    Encodes every observation along each path, rolls the transition model
    forward from the first state, and checks whether the nearest encoded
    state along the path is the one at index ``args_eval.num_steps``.

    Args:
        args: Training-time arguments (hyper-parameters, batch size, seed).
        args_eval: Evaluation arguments (num_steps, dedup, no_transition,
            optional sacred logging).
        model_file: Path to the saved model ``state_dict``.

    Returns:
        Tuple ``(hits, 0.)``; this variant does not compute MRR.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    ex = None
    if args_eval.sacred:
        from sacred import Experiment
        from sacred.observers import MongoObserver
        ex = Experiment(args_eval.sacred_name)
        ex.observers.append(MongoObserver(url=constants.MONGO_URI, db_name=constants.DB_NAME))
        ex.add_config({
            "batch_size": args.batch_size,
            "epochs": args.epochs,
            "learning_rate": args.learning_rate,
            "encoder": args.encoder,
            "num_objects": args.num_objects,
            "custom_neg": args.custom_neg,
            "in_ep_prob": args.in_ep_prob,
            "seed": args.seed,
            "dataset": args.dataset,
            "save_folder": args.save_folder,
            "eval_dataset": args_eval.dataset,
            "num_steps": args_eval.num_steps,
            "use_action_attention": args.use_action_attention
        })

    device = torch.device('cuda' if args.cuda else 'cpu')

    # NOTE(review): path_length=10 is hard-coded, so this assumes
    # args_eval.num_steps <= 9 — confirm against callers.
    dataset = utils.PathDatasetStateIds(
        hdf5_file=args.dataset, path_length=10)
    # Previously the loader hard-coded batch_size=100 while the loop below
    # skips batches whose size differs from args.batch_size — with any
    # other batch size every batch was skipped and hits became NaN.
    eval_loader = data.DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

    # Get a data sample to infer the observation shape.  next(iter(...))
    # replaces the Python-2-only ``.next()`` iterator call.
    obs = next(iter(eval_loader))[0]
    input_shape = obs[0][0].size()

    model = modules.ContrastiveSWM(
        embedding_dim=args.embedding_dim,
        hidden_dim=args.hidden_dim,
        action_dim=args.action_dim,
        input_dims=input_shape,
        num_objects=args.num_objects,
        sigma=args.sigma,
        hinge=args.hinge,
        ignore_action=args.ignore_action,
        copy_action=args.copy_action,
        split_mlp=args.split_mlp,
        same_ep_neg=args.same_ep_neg,
        only_same_ep_neg=args.only_same_ep_neg,
        immovable_bit=args.immovable_bit,
        split_gnn=args.split_gnn,
        no_loss_first_two=args.no_loss_first_two,
        bisim_model=make_pairwise_encoder() if args.bisim_model_path else None,
        encoder=args.encoder,
        use_action_attention=args.use_action_attention
    ).to(device)

    model.load_state_dict(torch.load(model_file))
    model.eval()

    hits_list = []

    with torch.no_grad():

        for batch_idx, data_batch in enumerate(eval_loader):

            data_batch = [[t.to(
                device) for t in tensor] for tensor in data_batch]

            observations, actions, state_ids = data_batch

            # Drop the final partial batch to keep shapes uniform.
            if observations[0].size(0) != args.batch_size:
                continue

            # Encode every observation along the path.
            states = []
            for obs in observations:
                states.append(model.obj_encoder(model.obj_extractor(obs)))
            states = torch.stack(states, dim=0)
            state_ids = torch.stack(state_ids, dim=0)

            pred_state = states[0]
            if not args_eval.no_transition:
                for i in range(args_eval.num_steps):
                    pred_state = model.forward_transition(pred_state, actions[i])

            # pred_state: [B, |O|, D]
            # states: [path_len, B, |O|, D]
            # pred_state_flat: [B, X]
            # states_flat: [path_len, B, X]
            pred_state_flat = pred_state.reshape((pred_state.size(0), pred_state.size(1) * pred_state.size(2)))
            states_flat = states.reshape((states.size(0), states.size(1), states.size(2) * states.size(3)))

            # dist_matrix: [path_len, B]; pick the nearest path position.
            dist_matrix = (states_flat - pred_state_flat[None]).pow(2).sum(2)
            indices = torch.argmin(dist_matrix, dim=0)
            correct = indices == args_eval.num_steps

            # Check for duplicates: a wrong position with an identical
            # underlying state id still counts as correct.
            if args_eval.dedup:
                # torch.arange over the actual batch size replaces the
                # previously hard-coded list(range(100)).
                batch_range = torch.arange(indices.size(0), device=indices.device)
                equal_mask = torch.all(state_ids[indices, batch_range] == state_ids[args_eval.num_steps], dim=1)
                correct = correct + equal_mask

            # hits
            hits_list.append(correct.float().mean().item())

    hits = np.mean(hits_list)

    if ex is not None:
        # ugly hack: sacred requires a main function to log scalars.
        @ex.main
        def sacred_main():
            ex.log_scalar("hits", hits)
            ex.log_scalar("mrr", 0.)

        ex.run()

    print('Hits @ 1: {}'.format(hits))

    return hits, 0.
Example #5
0
def train_and_eval(args, eval_every, use_trans_model, ft_data_loader,
                   ft_eval_data_loader):
    """Train a ContrastiveSWM and periodically fine-tune/evaluate it.

    Trains with the contrastive (or NT-Xent) loss, keeps the best
    checkpoint by average epoch loss, and every *eval_every* epochs clones
    the model, fine-tunes the clone on a downstream task, and records the
    best downstream accuracy.

    Args:
        args: Argument namespace (hyper-parameters, paths, flags).
        eval_every: Run downstream fine-tuning every this many epochs.
        use_trans_model: Forwarded to ``fine_tune_and_eval_downstream``.
        ft_data_loader: Fine-tuning training data loader.
        ft_eval_data_loader: Fine-tuning evaluation data loader.

    Returns:
        List of ``(epoch, best_fine_tune_accuracy)`` tuples.
    """
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    now = datetime.datetime.now()
    timestamp = now.isoformat()

    if args.name == 'none':
        exp_name = timestamp
    else:
        exp_name = args.name

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    save_folder = '{}/{}/'.format(args.save_folder, exp_name)

    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    meta_file = os.path.join(save_folder, 'metadata.pkl')
    model_file = os.path.join(save_folder, 'model.pt')
    log_file = os.path.join(save_folder, 'log.txt')

    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger()
    logger.addHandler(logging.FileHandler(log_file, 'a'))
    # Redirect print to the logger so messages also land in log.txt.
    print = logger.info

    # Context manager closes the metadata file instead of leaking it.
    with open(meta_file, "wb") as f:
        pickle.dump({'args': args}, f)

    device = torch.device('cuda' if args.cuda else 'cpu')
    print("About to get dataset")

    transform = None
    if args.data_aug:
        transform = utils.get_data_augmentation()
        dataset = utils.StateTransitionsDataAugDataset(hdf5_file=args.dataset,
                                                       transforms=transform)
    else:
        dataset = utils.StateTransitionsDataset(hdf5_file=args.dataset)

    train_loader = data.DataLoader(dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers)
    print("Dataset loaded")

    # Get a data sample to infer the observation shape.  next(iter(...))
    # replaces the Python-2-only ``.next()`` iterator call.
    obs = next(iter(train_loader))[0]
    input_shape = obs[0].size()

    model = modules.ContrastiveSWM(embedding_dim=args.embedding_dim,
                                   hidden_dim=args.hidden_dim,
                                   action_dim=args.action_dim,
                                   input_dims=input_shape,
                                   num_objects=args.num_objects,
                                   sigma=args.sigma,
                                   hinge=args.hinge,
                                   ignore_action=args.ignore_action,
                                   copy_action=args.copy_action,
                                   encoder=args.encoder,
                                   use_nt_xent_loss=args.use_nt_xent_loss,
                                   temperature=args.temperature,
                                   use_slot_attn=args.slot_attn).to(device)

    model.apply(utils.weights_init)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    # Train model.
    print('Starting model training...')
    step = 0
    best_loss = 1e9

    epoch_acc_list = []
    for epoch in range(1, args.epochs + 1):
        model.train()
        train_loss = 0

        for batch_idx, data_batch in enumerate(train_loader):
            data_batch = [tensor.to(device) for tensor in data_batch]
            optimizer.zero_grad()

            if model.use_nt_xent_loss:
                loss = model.nt_xent_loss(*data_batch)
            else:
                loss = model.contrastive_loss(*data_batch)

            loss.backward()
            train_loss += loss.item()
            optimizer.step()

            if args.decoder:
                # NOTE(review): optimizer_dec is never created in this
                # function (unlike train_c_swm) — this raises NameError
                # whenever args.decoder is set; confirm intended usage.
                optimizer_dec.step()

            if batch_idx % args.log_interval == 0:
                print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data_batch[0]),
                    len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.item() / len(data_batch[0])))

            step += 1

        avg_loss = train_loss / len(train_loader.dataset)
        print('====> Epoch: {} Average loss: {:.6f}'.format(epoch, avg_loss))

        # Keep the checkpoint with the best average training loss.
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), model_file)

        if epoch % eval_every == 0 or epoch == args.epochs:
            # Copy model for fine tuning (deepcopy does not work directly
            # on the model, so rebuild and load a copied state dict).
            model_clone = modules.ContrastiveSWM(
                embedding_dim=args.embedding_dim,
                hidden_dim=args.hidden_dim,
                action_dim=args.action_dim,
                input_dims=input_shape,
                num_objects=args.num_objects,
                sigma=args.sigma,
                hinge=args.hinge,
                ignore_action=args.ignore_action,
                copy_action=args.copy_action,
                encoder=args.encoder,
                use_nt_xent_loss=args.use_nt_xent_loss,
                temperature=args.temperature,
                use_slot_attn=args.slot_attn).to('cpu')
            model_clone.load_state_dict(copy.deepcopy(model.state_dict()))
            model_clone.to(device)
            ft_acc_list = fine_tune_and_eval_downstream(
                model_clone,
                device,
                ft_data_loader,
                ft_eval_data_loader,
                10,
                acc_every=6,
                use_trans_model=use_trans_model,
                epochs=60,
                learning_rate=5e-4)
            model_clone.to('cpu')
            # Best accuracy from the list is the evaluation result for
            # this training epoch.
            best_ft_acc = max(ft_acc_list)
            epoch_acc_list.append((epoch, best_ft_acc))
            print('[Epoch %d] test acc: %.3f' % (epoch, best_ft_acc))

    return epoch_acc_list
Example #6
0
    bisim_model.load_state_dict(torch.load(args.bisim_model_path))

# Collect every ContrastiveSWM hyper-parameter in one mapping so the
# constructor call stays short and the configuration is easy to diff.
swm_kwargs = {
    'embedding_dim': args.embedding_dim,
    'hidden_dim': args.hidden_dim,
    'action_dim': args.action_dim,
    'input_dims': input_shape,
    'num_objects': args.num_objects,
    'sigma': args.sigma,
    'hinge': args.hinge,
    'ignore_action': args.ignore_action,
    'copy_action': args.copy_action,
    'split_mlp': args.split_mlp,
    'same_ep_neg': args.same_ep_neg,
    'only_same_ep_neg': args.only_same_ep_neg,
    'immovable_bit': args.immovable_bit,
    'split_gnn': args.split_gnn,
    'no_loss_first_two': args.no_loss_first_two,
    'gamma': args.gamma,
    'bisim_metric': bisim_metric,
    'bisim_eps': args.bisim_eps,
    'next_state_neg': args.next_state_neg,
    'nl_type': args.nl_type,
    'encoder': args.encoder,
    'use_coord_grid': args.coord_grid,
    # Multiple negatives are in play when more than one is requested
    # or when negatives are mixed.
    'many_negs': args.num_negs != 1 or args.mix_negs,
    'detach_negs': args.detach_negs,
    'mix_negs': args.mix_negs,
    'use_action_attention': args.use_action_attention,
}
model = modules.ContrastiveSWM(**swm_kwargs).to(device)

# Apply the project-wide weight initialization scheme.
model.apply(utils.weights_init)
Example #7
0
                              shuffle=False,
                              num_workers=4)

# Get a data sample to infer the observation shape.  next(iter(...))
# replaces the Python-2-only ``.next()`` call, which does not exist on
# Python-3 iterators (and was removed from newer DataLoader iterators).
obs = next(iter(eval_loader))[0]
input_shape = obs[0][0].size()

model = modules.ContrastiveSWM(
    embedding_dim=args.embedding_dim,
    hidden_dim=args.hidden_dim,
    action_dim=args.action_dim,
    input_dims=input_shape,
    num_objects=args.num_objects,
    sigma=args.sigma,
    hinge=args.hinge,
    ignore_action=args.ignore_action,
    copy_action=args.copy_action,
    split_mlp=args.split_mlp,
    same_ep_neg=args.same_ep_neg,
    only_same_ep_neg=args.only_same_ep_neg,
    immovable_bit=args.immovable_bit,
    split_gnn=args.split_gnn,
    no_loss_first_two=args.no_loss_first_two,
    encoder=args.encoder,
    use_action_attention=args.use_action_attention).to(device)

# Restore trained weights and switch to inference mode.
model.load_state_dict(torch.load(model_file))
model.eval()

all_states = []

with torch.no_grad():
Example #8
0
#     action_dim=args.action_dim,
#     input_dims=input_shape,
#     num_objects=args.num_objects,
#     sigma=args.sigma,
#     hinge=args.hinge,
#     ignore_action=args.ignore_action,
#     copy_action=args.copy_action,
#     encoder=args.encoder).to(device)

# Gather the constructor configuration in one mapping, then build the
# model and move it to the target device.
swm_config = {
    'embedding_dim': args.embedding_dim,
    'hidden_dim': args.hidden_dim,
    'action_dim': args.action_dim,
    'input_dims': input_shape,
    'num_objects': args.num_objects,
    'sigma': args.sigma,
    'hinge': args.hinge,
    'ignore_action': args.ignore_action,
    'copy_action': args.copy_action,
    'encoder': args.encoder,
    'use_nt_xent_loss': args.use_nt_xent_loss,
    'temperature': args.temperature,
    'use_slot_attn': args.slot_attn,
}
model = modules.ContrastiveSWM(**swm_config).to(device)

# Restore trained weights and switch to inference mode.
model.load_state_dict(torch.load(model_file))
model.eval()

# Ranking accumulators: only Hits@1 is tracked here
# (top-5 / top-10 cutoffs are disabled).
topk = [1]
hits_at = defaultdict(int)
num_samples = 0
rr_sum = 0
Example #9
0
def train_c_swm(args):
    """Train a ContrastiveSWM model and return it.

    Sets up logging and a save folder, trains with the contrastive loss
    (or a pixel-reconstruction loss when ``args.decoder`` is set), and
    saves the checkpoint with the best average epoch loss to ``model.pt``.

    Args:
        args: Argument namespace (hyper-parameters, dataset path, flags).

    Returns:
        The trained model (last epoch's in-memory weights; the best
        checkpoint by loss is the one on disk).
    """
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    print("Inside train_c_swm")

    now = datetime.datetime.now()
    timestamp = now.isoformat()

    if args.name == 'none':
        exp_name = timestamp
    else:
        exp_name = args.name

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    save_folder = '{}/{}/'.format(args.save_folder, exp_name)

    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    meta_file = os.path.join(save_folder, 'metadata.pkl')
    model_file = os.path.join(save_folder, 'model.pt')
    log_file = os.path.join(save_folder, 'log.txt')

    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger()
    logger.addHandler(logging.FileHandler(log_file, 'a'))
    #print = logger.info

    # Context manager closes the metadata file instead of leaking it.
    with open(meta_file, "wb") as f:
        pickle.dump({'args': args}, f)

    device = torch.device('cuda' if args.cuda else 'cpu')
    print("About to get dataset")
    dataset = utils.StateTransitionsDataset(hdf5_file=args.dataset)
    train_loader = data.DataLoader(dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers)
    print("Dataset loaded")
    # Get a data sample to infer the observation shape.  next(iter(...))
    # replaces the Python-2-only ``.next()`` iterator call.
    obs = next(iter(train_loader))[0]
    input_shape = obs[0].size()

    model = modules.ContrastiveSWM(embedding_dim=args.embedding_dim,
                                   hidden_dim=args.hidden_dim,
                                   action_dim=args.action_dim,
                                   input_dims=input_shape,
                                   num_objects=args.num_objects,
                                   sigma=args.sigma,
                                   hinge=args.hinge,
                                   ignore_action=args.ignore_action,
                                   copy_action=args.copy_action,
                                   encoder=args.encoder).to(device)

    model.apply(utils.weights_init)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    if args.decoder:
        # NOTE(review): an encoder value outside {large, medium, small}
        # leaves ``decoder`` undefined and crashes below — confirm the
        # argument parser restricts the choices.
        if args.encoder == 'large':
            decoder = modules.DecoderCNNLarge(
                input_dim=args.embedding_dim,
                num_objects=args.num_objects,
                hidden_dim=args.hidden_dim // 16,
                output_size=input_shape).to(device)
        elif args.encoder == 'medium':
            decoder = modules.DecoderCNNMedium(
                input_dim=args.embedding_dim,
                num_objects=args.num_objects,
                hidden_dim=args.hidden_dim // 16,
                output_size=input_shape).to(device)
        elif args.encoder == 'small':
            decoder = modules.DecoderCNNSmall(
                input_dim=args.embedding_dim,
                num_objects=args.num_objects,
                hidden_dim=args.hidden_dim // 16,
                output_size=input_shape).to(device)
        decoder.apply(utils.weights_init)
        optimizer_dec = torch.optim.Adam(decoder.parameters(),
                                         lr=args.learning_rate)

    # Train model.
    print('Starting model training...')
    step = 0
    best_loss = 1e9

    for epoch in range(1, args.epochs + 1):
        model.train()
        train_loss = 0

        for batch_idx, data_batch in enumerate(train_loader):
            data_batch = [tensor.to(device) for tensor in data_batch]
            optimizer.zero_grad()

            if args.decoder:
                # Pixel-reconstruction loss on current and predicted next
                # state, normalized by batch size.
                optimizer_dec.zero_grad()
                obs, action, next_obs = data_batch
                objs = model.obj_extractor(obs)
                state = model.obj_encoder(objs)

                rec = torch.sigmoid(decoder(state))
                loss = F.binary_cross_entropy(rec, obs,
                                              reduction='sum') / obs.size(0)

                next_state_pred = state + model.transition_model(state, action)
                next_rec = torch.sigmoid(decoder(next_state_pred))
                next_loss = F.binary_cross_entropy(
                    next_rec, next_obs, reduction='sum') / obs.size(0)
                loss += next_loss
            else:
                loss = model.contrastive_loss(*data_batch)

            loss.backward()
            train_loss += loss.item()
            optimizer.step()

            if args.decoder:
                optimizer_dec.step()

            if batch_idx % args.log_interval == 0:
                print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data_batch[0]),
                    len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.item() / len(data_batch[0])))

            step += 1

        avg_loss = train_loss / len(train_loader.dataset)

        # Keep the checkpoint with the best average training loss.
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), model_file)

    return model