Example #1
def main(params):
    params['anchor_list'] = ANCHOR_LIST
    logger = logging.getLogger(params['alias'])
    set_device(logger, params['gpu_id'])
    saver = ModelSaver(params, os.path.abspath('./third_party/densevid_eval'))
    model = construct_model(params, saver, logger)

    model = model.cuda()

    training_set = ANetDataSample(params['train_data'], params['feature_path'],
                                  params['translator_path'], params['video_sample_rate'], logger)
    val_set = ANetDataFull(params['val_data'], params['feature_path'],
                           params['translator_path'], params['video_sample_rate'], logger)
    train_loader = DataLoader(training_set, batch_size=params['batch_size'], shuffle=True,
                              num_workers=params['num_workers'], collate_fn=collate_fn, drop_last=True)
    val_loader = DataLoader(val_set, batch_size=params['batch_size'] // 3, shuffle=True,
                            num_workers=params['num_workers'], collate_fn=collate_fn, drop_last=True)

    optimizer = torch.optim.SGD(model.get_parameter_group(params),
                                lr=params['lr'], weight_decay=params['weight_decay'], momentum=params['momentum'])

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=params['lr_step'], gamma=params["lr_decay_rate"])
    eval(model, val_loader, params, logger, 0, saver)
    for step in range(params['training_epoch']):
        lr_scheduler.step()
        train(model, train_loader, params, logger, step, optimizer)

        # validation and saving
        if step % params['test_interval'] == 0:
            eval(model, val_loader, params, logger, step, saver)
        if step % params['save_model_interval'] == 0 and step != 0:
            saver.save_model(model, step, {'step': step})
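A note on the loop above: since PyTorch 1.1, lr_scheduler.step() is meant to be called after the epoch's optimizer updates, not before them; stepping the scheduler first effectively skips the first value of the schedule. A minimal sketch of the usual ordering, reusing the same (assumed) helpers and params keys from Example #1:

# Sketch only: the train/eval helpers and params keys from Example #1 are assumed.
for step in range(params['training_epoch']):
    train(model, train_loader, params, logger, step, optimizer)  # optimizer.step() runs inside train()
    lr_scheduler.step()  # advance the LR schedule once per epoch, after the optimizer has stepped

    if step % params['test_interval'] == 0:
        eval(model, val_loader, params, logger, step, saver)
    if step % params['save_model_interval'] == 0 and step != 0:
        saver.save_model(model, step, {'step': step})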
Example #2
def main(params):
    logger = logging.getLogger(params['alias'])
    gpu_id = set_device(logger, params['gpu_id'])
    logger = logging.getLogger(params['alias'] + '(%d)' % gpu_id)
    saver = ModelSaver(params, os.path.abspath('./third_party/densevid_eval'))
    model = construct_model(params, saver, logger)

    model = model.cuda()

    val_set = ANetDataSample(params['val_data'], params['feature_path'],
                             params['translator_path'],
                             params['video_sample_rate'], logger)
    val_loader = DataLoader(val_set,
                            batch_size=params['batch_size'],
                            shuffle=True,
                            num_workers=params['num_workers'],
                            collate_fn=collate_fn,
                            drop_last=True)

    eval(model, val_loader, params, logger, 0, saver)
Example #3
def main(params):
    logger = logging.getLogger(params['alias'])
    set_device(logger, params['gpu_id'])
    saver = ModelSaver(params, os.path.abspath('./third_party/densevid_eval'))
    torch.manual_seed(params['rand_seed'])
    model_sl, model_cg = construct_model(params, saver, logger)

    model_sl, model_cg = model_sl.cuda(), model_cg.cuda()

    val_set = ANetDataFull(params['val_data'], params['feature_path'],
                           params['translator_path'],
                           params['video_sample_rate'], logger)

    val_loader = DataLoader(val_set,
                            batch_size=params['batch_size'],
                            shuffle=True,
                            num_workers=params['num_workers'],
                            collate_fn=collate_fn,
                            drop_last=True)
    print(params)
    eval(model_sl, model_cg, val_loader, logger, saver, params)
Example #4
def train(
    main_config,
    model_config,
    model_name,
    experiment_name,
    dataset_name,
):
    main_cfg = MainConfig(main_config)
    model = MODELS[model_name]
    dataset = dataset_type.get_dataset(dataset_name)

    train_data = dataset.train_set_pairs()
    vectorizer = DatasetVectorizer(main_cfg.model_dir,
                                   raw_sentence_pairs=train_data)

    dataset_helper = Dataset(vectorizer, dataset, main_cfg.batch_size)
    max_sentence_len = vectorizer.max_sentence_len
    vocabulary_size = vectorizer.vocabulary_size

    train_mini_sen1, train_mini_sen2, train_mini_labels = dataset_helper.pick_train_mini_batch()
    train_mini_labels = train_mini_labels.reshape(-1, 1)

    test_sentence1, test_sentence2 = dataset_helper.test_instances()
    test_labels = dataset_helper.test_labels()
    test_labels = test_labels.reshape(-1, 1)

    num_batches = dataset_helper.num_batches
    model = model(
        max_sentence_len,
        vocabulary_size,
        main_config,
        model_config,
    )
    model_saver = ModelSaver(
        main_cfg.model_dir,
        experiment_name,
        main_cfg.checkpoints_to_keep,
    )
    config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=main_cfg.log_device_placement,
    )

    with tf.Session(config=config) as session:
        global_step = 0
        init = tf.global_variables_initializer()
        session.run(init)
        log_saver = LogSaver(
            main_cfg.logs_path,
            experiment_name,
            dataset_name,
            session.graph,
        )
        model_evaluator = ModelEvaluator(model, session)

        metrics = {'acc': 0.0}
        time_per_epoch = []

        log('Training model for {} epochs'.format(main_cfg.num_epochs))
        for epoch in tqdm(range(main_cfg.num_epochs), desc='Epochs'):
            start_time = time.time()

            train_sentence1, train_sentence2 = dataset_helper.train_instances(
                shuffle=True)
            train_labels = dataset_helper.train_labels()

            train_batch_helper = BatchHelper(
                train_sentence1,
                train_sentence2,
                train_labels,
                main_cfg.batch_size,
            )

            # small eval set for measuring dev accuracy
            dev_sentence1, dev_sentence2, dev_labels = dataset_helper.dev_instances()
            dev_labels = dev_labels.reshape(-1, 1)

            tqdm_iter = tqdm(range(num_batches),
                             total=num_batches,
                             desc="Batches",
                             leave=False,
                             postfix=metrics)
            for batch in tqdm_iter:
                global_step += 1
                sentence1_batch, sentence2_batch, labels_batch = train_batch_helper.next(
                    batch)
                feed_dict_train = {
                    model.x1: sentence1_batch,
                    model.x2: sentence2_batch,
                    model.is_training: True,
                    model.labels: labels_batch,
                }
                loss, _ = session.run([model.loss, model.opt],
                                      feed_dict=feed_dict_train)

                if batch % main_cfg.eval_every == 0:
                    feed_dict_train = {
                        model.x1: train_mini_sen1,
                        model.x2: train_mini_sen2,
                        model.is_training: False,
                        model.labels: train_mini_labels,
                    }

                    train_accuracy, train_summary = session.run(
                        [model.accuracy, model.summary_op],
                        feed_dict=feed_dict_train,
                    )
                    log_saver.log_train(train_summary, global_step)

                    feed_dict_dev = {
                        model.x1: dev_sentence1,
                        model.x2: dev_sentence2,
                        model.is_training: False,
                        model.labels: dev_labels
                    }

                    dev_accuracy, dev_summary = session.run(
                        [model.accuracy, model.summary_op],
                        feed_dict=feed_dict_dev,
                    )
                    log_saver.log_dev(dev_summary, global_step)
                    tqdm_iter.set_postfix(
                        dev_acc='{:.2f}'.format(float(dev_accuracy)),
                        train_acc='{:.2f}'.format(float(train_accuracy)),
                        loss='{:.2f}'.format(float(loss)),
                        epoch=epoch)

                if global_step % main_cfg.save_every == 0:
                    model_saver.save(session, global_step=global_step)

            model_evaluator.evaluate_dev(dev_sentence1, dev_sentence2,
                                         dev_labels)

            end_time = time.time()
            total_time = timer(start_time, end_time)
            time_per_epoch.append(total_time)

            model_saver.save(session, global_step=global_step)

        model_evaluator.evaluate_test(test_sentence1, test_sentence2,
                                      test_labels)
        model_evaluator.save_evaluation(
            '{}/{}'.format(main_cfg.model_dir, experiment_name),
            time_per_epoch[-1], dataset)
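The ModelSaver in this TF1 example is constructed with a model directory, an experiment name, and the number of checkpoints to keep, and is then driven by model_saver.save(session, global_step=...). Its implementation is not part of this listing; a hypothetical sketch of such a wrapper around tf.train.Saver, under those assumptions, might look like:

# Hypothetical sketch (the real ModelSaver class is not shown in these examples).
import os
import tensorflow as tf

class TFModelSaver:
    def __init__(self, model_dir, experiment_name, checkpoints_to_keep):
        # Keep only the most recent checkpoints_to_keep checkpoint files.
        self.save_path = os.path.join(model_dir, experiment_name, 'model')
        self.saver = tf.train.Saver(max_to_keep=checkpoints_to_keep)

    def save(self, session, global_step):
        os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
        self.saver.save(session, self.save_path, global_step=global_step)

As in the example, the wrapper must be built after the model's variables exist in the graph, since tf.train.Saver collects the variables at construction time.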
Example #5
def training(train_dataset, **kwargs):
    train_start_time = datetime.now().strftime('%Y%m%d-%H%M%S')
    logger.debug('Set parameters and the model. Start training time is: %s' %
                 train_start_time)

    model = kwargs['model']
    loss = kwargs['loss']
    optimizer = kwargs['optimizer']

    batch_time = Meter()
    data_time = Meter()
    loss_meter = Meter()

    # do not modify opt.lr; use adjustable_lr everywhere in this function
    adjustable_lr = opt.lr
    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=opt.batch_size,
                                                   shuffle=True,
                                                   num_workers=opt.num_workers,
                                                   drop_last=False)

    logger.debug('Starting training for %d epochs:' % opt.epochs)

    model_saver_val = ModelSaver(
        path=join(opt.storage, 'models', opt.model_name, opt.log_name, 'val'))
    # model_saver_test = ModelSaver(path=join(opt.storage, 'models', kwargs['name'], opt.log_name, 'test'))
    # if you need test set validation, add corresponding functions below similarly to val set

    for epoch in range(opt.epochs):
        model.to(opt.device)
        model.train()
        train_dataset.epoch = epoch

        logger.debug('Epoch # %d' % epoch)

        # adjust learning rate if necessary
        if epoch and epoch % 50 == 0:  # every 50 epochs
            adjustable_lr = adjust_lr(optimizer, adjustable_lr)

        if epoch == 0:
            testing(train_dataset,
                    model,
                    loss,
                    epoch=-1,
                    mode='train',
                    time_id=train_start_time)

        end = time.time()

        n_train_samples = 0
        for i, input in enumerate(train_dataloader):
            data_time.update(time.time() - end)
            labels = input['labels']

            if len(labels) == 1: raise EnvironmentError('LAZY ANNA')

            output = model(input)
            loss_values = loss(output, input)
            loss_meter.update(loss_values.item(), len(labels))

            optimizer.zero_grad()
            loss_values.backward()
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            n_train_samples += len(labels)

            if i % opt.debug_freq == 0 and i:
                logger.debug(
                    'Epoch: [{0}][{1}/{2}]\tTime {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\tLoss {loss.val:.4f} ({loss.avg:.4f})\t'
                    .format(epoch,
                            i,
                            len(train_dataloader),
                            batch_time=batch_time,
                            data_time=data_time,
                            loss=loss_meter))

        logger.debug('Number of training samples within one epoch: %d' %
                     n_train_samples)
        logger.debug('Loss: %f' % loss_meter.avg)
        loss_meter.reset()

        if opt.test_freq and epoch % opt.test_freq == 0:
            # test the model every opt.test_freq epoch
            testing(train_dataset,
                    model,
                    loss,
                    epoch=epoch,
                    mode='train',
                    time_id=train_start_time)

        if opt.save_model and epoch % opt.save_model == 0:
            if opt.test_val:
                check_val = testing(kwargs['val_dataset'],
                                    model,
                                    loss,
                                    epoch=epoch,
                                    mode='val',
                                    time_id=train_start_time)
                if model_saver_val.check(check_val):
                    save_dict = {
                        'epoch': epoch,
                        'state_dict': copy.deepcopy(model.state_dict()),
                        'optimizer':
                        copy.deepcopy(optimizer.state_dict().copy())
                    }

            logger.debug(opt.log_name)

        if opt.save_model:
            model_saver_val.save()

    # save the last checkpoint of the training
    if opt.save_model:
        save_dict = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        os.makedirs(join(opt.storage, 'models', opt.model_name, opt.log_name),
                    exist_ok=True)
        torch.save(
            save_dict,
            join(opt.storage, 'models', opt.model_name, opt.log_name,
                 'last_%d.pth.tar' % epoch))

    return model
Example #6
def train_net(net, paras):
    # parameters
    img_dir = paras.image_dir
    anno_path = paras.anno_path
    checkpoint_dir = paras.model_save_dir
    val_percent = 0.1
    epochs = paras.epochs
    batch_size = paras.batch_size
    lr = paras.learning_rate
    num_workers = 2

    # torch model saver
    saver = ModelSaver(max_save_num=5)

    # load dataset info
    dataset = load_dataset_info(img_dir, anno_path)
    train_set_info, valid_set_info = split_dataset_info(dataset, val_percent)

    # build dataloader
    building_trainset = Building_Dataset(train_set_info)
    building_validset = Building_Dataset(valid_set_info)
    train_dataloader = torch.utils.data.DataLoader(building_trainset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=num_workers)
    valid_dataloader = torch.utils.data.DataLoader(building_validset,
                                                   batch_size=batch_size,
                                                   shuffle=False,
                                                   num_workers=num_workers)

    # optimizer
    optimizer = optim.SGD(net.parameters(),
                          lr=lr,
                          momentum=0.9,
                          weight_decay=0.0005)

    # loss function
    #criterion = nn.L1Loss(reduce=True, size_average=True)
    criterion = nn.BCELoss()

    train_num = len(building_trainset)
    valid_num = len(building_validset)
    print('''
    Starting training:
        Total Epochs: {}
        Batch size: {}
        Learning rate: {}
        Training size: {}
        Validation size: {}
        Checkpoints save dir: {}
    '''.format(epochs, batch_size, lr, train_num, valid_num, checkpoint_dir))

    # ------------------------
    # start training...
    # ------------------------
    best_valid_loss = 1000
    for epoch in range(1, epochs + 1):
        print('Starting epoch {}/{}.'.format(epoch, epochs))

        # training
        net.train()
        epoch_loss = 0
        for idx, data in enumerate(train_dataloader):
            imgs, true_masks = data

            imgs = imgs.cuda()
            true_masks = true_masks.cuda()

            pred_masks = net(imgs)

            # compute loss
            loss = criterion(pred_masks, true_masks)
            epoch_loss += loss.item()

            if idx % 10 == 0:
                print(f'{idx}/{len(train_dataloader)}, loss: {loss.item()}')

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        epoch_loss = epoch_loss / len(train_dataloader)
        print('Epoch finished! Loss: {}\n'.format(epoch_loss))

        # validation
        net.eval()
        valid_loss = 0
        with torch.no_grad():
            for idx, data in enumerate(valid_dataloader):
                if idx % 10 == 0:
                    print(idx, '/', len(valid_dataloader))
                imgs, true_masks = data

                imgs = imgs.cuda()
                true_masks = true_masks.cuda()

                # inference
                pred_masks = net(imgs)
                # compute loss
                loss = criterion(pred_masks, true_masks)
                valid_loss += loss.item()
        valid_loss = valid_loss / len(valid_dataloader)
        print('Validation finished! Loss: {}  Previous best loss: {}\n'.format(
            valid_loss, best_valid_loss))

        # save check_point
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            print('New best model found, saving checkpoint {}...'.format(epoch))
            model_save_path = os.path.join(
                checkpoint_dir, '{}_CP{}.pth'.format(best_valid_loss, epoch))
            #torch.save(net.state_dict(), model_save_path)
            saver.save_new_model(net, model_save_path)
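Here the saver is created with max_save_num=5 and fed a new checkpoint through save_new_model() whenever the validation loss improves, so only a bounded number of checkpoint files remain on disk. The ModelSaver class itself is not included in this listing; a hypothetical sketch of a rolling saver with that interface:

# Hypothetical sketch of a rolling checkpoint saver (the real ModelSaver is not shown).
import os
from collections import deque

import torch

class RollingModelSaver:
    def __init__(self, max_save_num=5):
        self.max_save_num = max_save_num
        self._saved_paths = deque()

    def save_new_model(self, net, model_save_path):
        # Write the new checkpoint, then drop the oldest ones beyond the limit.
        torch.save(net.state_dict(), model_save_path)
        self._saved_paths.append(model_save_path)
        while len(self._saved_paths) > self.max_save_num:
            oldest = self._saved_paths.popleft()
            if os.path.exists(oldest):
                os.remove(oldest)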
Example #7
def main(params):
    logger = logging.getLogger(params['alias'])
    gpu_id = set_device(logger, params['gpu_id'])
    logger = logging.getLogger(params['alias'] + '(%d)' % gpu_id)
    set_device(logger, params['gpu_id'])
    saver = ModelSaver(params, os.path.abspath('./third_party/densevid_eval'))
    model_sl, model_cg = construct_model(params, saver, logger)

    model_sl, model_cg = model_sl.cuda(), model_cg.cuda()

    training_set = ANetDataSample(params['train_data'], params['feature_path'],
                                  params['translator_path'],
                                  params['video_sample_rate'], logger)
    val_set = ANetDataFull(params['val_data'], params['feature_path'],
                           params['translator_path'],
                           params['video_sample_rate'], logger)
    train_loader_cg = DataLoader(training_set,
                                 batch_size=params['batch_size'],
                                 shuffle=True,
                                 num_workers=params['num_workers'],
                                 collate_fn=collate_fn,
                                 drop_last=True)
    train_loader_sl = DataLoader(training_set,
                                 batch_size=6,
                                 shuffle=True,
                                 num_workers=params['num_workers'],
                                 collate_fn=collate_fn,
                                 drop_last=True)
    val_loader = DataLoader(val_set,
                            batch_size=8,
                            shuffle=True,
                            num_workers=params['num_workers'],
                            collate_fn=collate_fn,
                            drop_last=True)

    optimizer_sl = torch.optim.SGD(model_sl.get_parameter_group(params),
                                   lr=params['lr'],
                                   weight_decay=params['weight_decay'],
                                   momentum=params['momentum'])

    optimizer_cg = torch.optim.SGD(list(
        chain(model_sl.get_parameter_group_c(params),
              model_cg.get_parameter_group(params))),
                                   lr=params['lr'],
                                   weight_decay=params['weight_decay'],
                                   momentum=params['momentum'])

    optimizer_cg_n = torch.optim.SGD(model_cg.get_parameter_group(params),
                                     lr=params['lr'],
                                     weight_decay=params['weight_decay'],
                                     momentum=params['momentum'])

    lr_scheduler_sl = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_sl,
        milestones=params['lr_step'],
        gamma=params["lr_decay_rate"])

    lr_scheduler_cg = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_cg,
        milestones=params['lr_step'],
        gamma=params["lr_decay_rate"])

    evaluator = CaptionEvaluator(training_set)
    saver.save_model(
        model_sl, 0, {
            'step': 0,
            'model_sl': model_sl.state_dict(),
            'model_cg': model_cg.state_dict()
        })
    # eval(model_sl, model_cg, val_loader, logger, saver, params, -1)
    for step in range(params['training_epoch']):
        lr_scheduler_cg.step()
        lr_scheduler_sl.step()

        if step < params['pretrain_epoch']:
            pretrain_cg(model_cg, train_loader_cg, params, logger, step,
                        optimizer_cg_n)
        elif step % params['alter_step'] != 0:
            train_cg(model_cg, model_sl, train_loader_cg, params, logger, step,
                     optimizer_cg)
        else:
            train_sl(model_cg, model_sl, train_loader_sl, evaluator, params,
                     logger, step, optimizer_sl)

        # validation and saving
        # if step % params['test_interval'] == 0:
        #     eval(model_sl, model_cg, val_loader, logger, saver, params, step)
        if step % params['save_model_interval'] == 0 and step != 0:
            saver.save_model(
                model_sl, step, {
                    'step': step,
                    'model_sl': model_sl.state_dict(),
                    'model_cg': model_cg.state_dict()
                })
Example #8
File: train.py  Project: tengyiyang/LIReC
def training(train_dataset, **kwargs):
    train_start_time = datetime.now().strftime('%Y%m%d-%H%M%S')
    print('set parameters and model, train start time: %s' % train_start_time)

    model = kwargs['model']
    loss = kwargs['loss']
    optimizer = kwargs['optimizer']

    batch_time = Averaging()
    data_time = Averaging()
    losses = Averaging()

    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=opt.batch_size,
                                                   shuffle=True,
                                                   num_workers=opt.num_workers,
                                                   drop_last=False)

    print('epochs: %s' % opt.epochs)
    model_saver_val = ModelSaver(path=opt.store_root)
    for epoch in range(opt.epochs):
        model.to(opt.device)
        model.train()
        train_dataset.epoch = epoch

        print('Epoch # %d' % epoch)
        end = time.time()
        counter = 0
        if opt.tr_sum_max:
            if epoch == 20:
                opt.tr_sum_max_flag = True
        for i, input in enumerate(train_dataloader):
            data_time.update(time.time() - end)
            labels = input['labels']
            if len(labels) == 1:
                continue
            output = model(input)
            loss_values = loss(output, input)
            losses.update(loss_values.item(), len(labels))

            optimizer.zero_grad()
            loss_values.backward()
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            counter += len(labels)

            if i % 10 == 0 and i:
                print(
                    'Epoch: [{0}][{1}/{2}]\tTime {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\tLoss {loss.val:.4f} ({loss.avg:.4f})\t'
                    .format(epoch,
                            i,
                            len(train_dataloader),
                            batch_time=batch_time,
                            data_time=data_time,
                            loss=losses))
        print(counter)
        print('loss: %f' % losses.avg)
        losses.reset()
        if epoch % opt.test_fr == 0:
            testing(train_dataset,
                    model,
                    loss,
                    total_iter=epoch,
                    mode='train',
                    train_start_time=train_start_time)
            if opt.test:
                check_val = testing(kwargs['val_dataset'],
                                    model,
                                    loss,
                                    total_iter=epoch,
                                    train_start_time=train_start_time,
                                    mode='val')
                if model_saver_val.check(check_val):
                    save_dict = {
                        'epoch': epoch,
                        'state_dict': copy.deepcopy(model.state_dict()),
                        'optimizer':
                        copy.deepcopy(optimizer.state_dict().copy())
                    }
                    model_saver_val.update(check_val, save_dict, epoch)

                    testing(kwargs['test_dataset'],
                            model,
                            loss,
                            total_iter=epoch,
                            train_start_time=train_start_time,
                            mode='test')

            print(opt.log_prefix)

        if opt.save_model and opt.save_model_often and epoch % 30 == 0:
            model_saver_val.save()

    check_str = join(opt.store_root)
    opt.resume_str = join(check_str, '%d.pth.tar' % epoch)
    if opt.save_model:
        save_dict = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        dir_check(check_str)
        torch.save(save_dict, opt.resume_str)
    return model
Example #9
def train(main_config, model_config, model_name, dataset_name):
    main_cfg = MainConfig(main_config)
    model = MODELS[model_name]
    dataset = DATASETS[dataset_name]()

    model_name = '{}_{}'.format(model_name,
                                main_config['PARAMS']['embedding_size'])

    train_data = dataset.train_set_pairs()
    vectorizer = DatasetVectorizer(train_data, main_cfg.model_dir)

    dataset_helper = Dataset(vectorizer, dataset, main_cfg.batch_size)
    max_sentence_len = vectorizer.max_sentence_len
    vocabulary_size = vectorizer.vocabulary_size

    train_mini_sen1, train_mini_sen2, train_mini_labels = dataset_helper.pick_train_mini_batch()
    train_mini_labels = train_mini_labels.reshape(-1, 1)

    test_sentence1, test_sentence2 = dataset_helper.test_instances()
    test_labels = dataset_helper.test_labels()
    test_labels = test_labels.reshape(-1, 1)

    num_batches = dataset_helper.num_batches
    model = model(max_sentence_len, vocabulary_size, main_config, model_config)
    model_saver = ModelSaver(main_cfg.model_dir, model_name,
                             main_cfg.checkpoints_to_keep)
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=main_cfg.log_device_placement)

    with tf.Session(config=config) as session:
        global_step = 0
        init = tf.global_variables_initializer()
        session.run(init)
        log_saver = LogSaver(main_cfg.logs_path, model_name, dataset_name,
                             session.graph)
        model_evaluator = ModelEvaluator(model, session)

        metrics = {'acc': 0.0}
        time_per_epoch = []
        for epoch in tqdm(range(main_cfg.num_epochs), desc='Epochs'):
            start_time = time.time()

            train_sentence1, train_sentence2 = dataset_helper.train_instances(
                shuffle=True)
            train_labels = dataset_helper.train_labels()

            train_batch_helper = BatchHelper(train_sentence1, train_sentence2,
                                             train_labels, main_cfg.batch_size)

            # small eval set for measuring dev accuracy
            dev_sentence1, dev_sentence2, dev_labels = dataset_helper.dev_instances()
            dev_labels = dev_labels.reshape(-1, 1)
            tqdm_iter = tqdm(range(num_batches),
                             total=num_batches,
                             desc="Batches",
                             leave=False,
                             postfix=metrics)
            for batch in tqdm_iter:
                global_step += 1
                sentence1_batch, sentence2_batch, labels_batch = train_batch_helper.next(
                    batch)
                feed_dict_train = {
                    model.x1: sentence1_batch,
                    model.x2: sentence2_batch,
                    model.is_training: True,
                    model.labels: labels_batch
                }

                loss, _ = session.run([model.loss, model.opt],
                                      feed_dict=feed_dict_train)

                if batch % main_cfg.eval_every == 0:
                    feed_dict_train = {
                        model.x1: train_mini_sen1,
                        model.x2: train_mini_sen2,
                        model.is_training: False,
                        model.labels: train_mini_labels
                    }

                    train_accuracy, train_summary = session.run(
                        [model.accuracy, model.summary_op],
                        feed_dict=feed_dict_train)
                    log_saver.log_train(train_summary, global_step)

                    feed_dict_dev = {
                        model.x1: dev_sentence1,
                        model.x2: dev_sentence2,
                        model.is_training: False,
                        model.labels: dev_labels
                    }

                    dev_accuracy, dev_summary = session.run(
                        [model.accuracy, model.summary_op],
                        feed_dict=feed_dict_dev)
                    log_saver.log_dev(dev_summary, global_step)
                    tqdm_iter.set_postfix(
                        dev_acc='{:.2f}'.format(float(dev_accuracy)),
                        train_acc='{:.2f}'.format(float(train_accuracy)),
                        loss='{:.2f}'.format(float(loss)),
                        epoch=epoch)

                if global_step % main_cfg.save_every == 0:
                    model_saver.save(session, global_step=global_step)

            model_evaluator.evaluate_dev(dev_sentence1, dev_sentence2,
                                         dev_labels)

            end_time = time.time()
            total_time = timer(start_time, end_time)
            time_per_epoch.append(total_time)

            model_saver.save(session, global_step=global_step)

        feed_dict_train = {
            model.x1: test_sentence1,
            model.x2: test_sentence2,
            model.is_training: False,
            model.labels: test_labels
        }

        #train_accuracy, train_summary, train_e = session.run([model.accuracy, model.summary_op, model.e],
        #                                            feed_dict=feed_dict_train)

        train_e = session.run([model.e], feed_dict=feed_dict_train)
        plt.clf()
        f = plt.figure(figsize=(8, 8.5))
        ax = f.add_subplot(1, 1, 1)

        i = ax.imshow(train_e[0][0], interpolation='nearest', cmap='gray')

        cbaxes = f.add_axes([0.2, 0, 0.6, 0.03])
        cbar = f.colorbar(i, cax=cbaxes, orientation='horizontal')
        cbar.ax.set_xlabel('Probability', labelpad=2)

        f.savefig('attention_maps.pdf', bbox_inches='tight')
        f.show()
        plt.show()

        feed_dict_test = {
            model.x1: test_sentence1,
            model.x2: test_sentence2,
            model.is_training: False,
            model.labels: test_labels
        }

        test_accuracy, test_summary = session.run(
            [model.accuracy, model.summary_op], feed_dict=feed_dict_test)
        print('tst_acc:%.2f loss:%.2f' % (test_accuracy, loss))

        model_evaluator.evaluate_test(test_sentence1, test_sentence2,
                                      test_labels)
        model_evaluator.save_evaluation(
            '{}/{}'.format(main_cfg.model_dir, model_name), time_per_epoch[-1],
            dataset)
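In Examples #5 and #8 the validation-driven saver exposes a check / update / save pattern: check(metric) tests whether the current validation result beats the best one seen so far, update(metric, save_dict, epoch) keeps the corresponding state in memory (called explicitly in Example #8), and save() writes the best state to disk. Those ModelSaver classes are not reproduced here; a minimal sketch of that pattern, assuming a higher metric value is better:

# Hypothetical best-on-validation saver matching the check/update/save calls
# in Examples #5 and #8 (assumes a higher metric value is better).
import os

import torch

class BestModelSaver:
    def __init__(self, path):
        self.path = path
        self.best_metric = float('-inf')
        self.best_state = None
        self.best_epoch = None

    def check(self, metric):
        # True if this validation score improves on the best so far.
        return metric > self.best_metric

    def update(self, metric, save_dict, epoch):
        # Remember the improved state until save() is called.
        self.best_metric = metric
        self.best_state = save_dict
        self.best_epoch = epoch

    def save(self):
        if self.best_state is None:
            return
        os.makedirs(self.path, exist_ok=True)
        torch.save(self.best_state,
                   os.path.join(self.path, '%d.pth.tar' % self.best_epoch))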