Example #1
    def train_one_epoch(self, epoch, accum_iter):
        self.model.train()
        self.lr_scheduler.step()  # per-epoch LR step (pre-1.1 PyTorch ordering; Example #9 steps at epoch end instead)

        average_meter_set = AverageMeterSet()
        tqdm_dataloader = tqdm(self.train_loader)

        for batch_idx, batch in enumerate(tqdm_dataloader):
            batch_size = batch[0].size(0)
            batch = [x.to(self.device) for x in batch]

            self.optimizer.zero_grad()
            loss = self.calculate_loss(batch)
            loss.backward()

            self.optimizer.step()

            average_meter_set.update('loss', loss.item())
            tqdm_dataloader.set_description(
                'Epoch {}, loss {:.3f} '.format(epoch+1, average_meter_set['loss'].avg))

            accum_iter += batch_size

            if self._needs_to_log(accum_iter):
                tqdm_dataloader.set_description('Logging to Tensorboard')
                log_data = {
                    'state_dict': (self._create_state_dict()),
                    'epoch': epoch,
                    'accum_iter': accum_iter,
                }
                log_data.update(average_meter_set.averages())
                self.log_extra_train_info(log_data)
                self.logger_service.log_train(log_data)

        return accum_iter
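
Every snippet on this page drives the same AverageMeterSet helper, whose definition the page never shows. Below is a minimal sketch reconstructed solely from the calls the examples make (update(), item access, averages(), the .avg/.sum/.count fields, and the '{meters[batch_time]:.3f}' format usage); each project ships its own, likely richer, version.

class AverageMeter:
    """Tracks a running sum, count, and average of one scalar metric."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        self.val = value
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count

    def __format__(self, format_spec):
        # lets '{meters[batch_time]:.3f}' render the running average
        return format(self.avg, format_spec)


class AverageMeterSet:
    """Dict-like collection of AverageMeters, created lazily on update()."""

    def __init__(self):
        self.meters = {}

    def __getitem__(self, key):
        return self.meters[key]

    def update(self, name, value, n=1):
        if name not in self.meters:
            self.meters[name] = AverageMeter()
        self.meters[name].update(value, n)

    def averages(self):
        return {name: meter.avg for name, meter in self.meters.items()}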
Example #2
    def eval_one_epoch(self, eval_loader, epoch=None):

        average_meter_set = AverageMeterSet()

        with torch.no_grad():
            tqdm_dataloader = tqdm(eval_loader)
            for batch_idx, batch in enumerate(tqdm_dataloader):
                batch = self.batch_to_device(batch)

                metrics = self.calculate_metrics(batch)

                for k, v in metrics.items():
                    average_meter_set.update(k, v)

                if self.args.local and batch_idx > 20:
                    break

                if batch_idx % 10 == 0 and batch_idx > 0:
                    descr = get_metric_descr(average_meter_set, self.metric_ks)
                    tqdm_dataloader.set_description(descr)

        descr = get_metric_descr(average_meter_set, self.metric_ks)
        if epoch is not None:
            print("\n Epoch {} avg.: {}".format(epoch + 1, descr))
        else:
            print("\n Avg.: {}".format(descr))

        return average_meter_set
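
Example #2 factors its progress string into a get_metric_descr helper that is not defined on this page. A plausible sketch, mirroring the inline description building of Examples #3, #4 and #5 (the exact signature and the [:3] truncation are assumptions):

def get_metric_descr(average_meter_set, metric_ks):
    # Hypothetical reconstruction (name and arguments taken from the call
    # sites in Example #2): formats the NDCG@k / Recall@k running averages
    # the same way Examples #3-#5 do inline.
    keys = ['NDCG@%d' % k for k in metric_ks[:3]] + \
           ['Recall@%d' % k for k in metric_ks[:3]]
    return ', '.join(
        '{} {:.3f}'.format(key.replace('NDCG', 'N').replace('Recall', 'R'),
                           average_meter_set[key].avg)
        for key in keys)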
Example #3
    def test(self):
        print('Test best model with test set!')

        best_model = torch.load(
            os.path.join(self.export_root, 'models',
                         'best_acc_model.pth')).get('model_state_dict')
        self.model.load_state_dict(best_model)
        self.model.eval()
        average_meter_set = AverageMeterSet()

        with torch.no_grad():
            tqdm_dataloader = tqdm(self.test_loader)
            for batch_idx, batch in enumerate(tqdm_dataloader):
                batch = [x.to(self.device) for x in batch]

                metrics, preds = self.calculate_metrics(batch)

                for k, v in metrics.items():
                    average_meter_set.update(k, v)
                description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]] +\
                                      ['Recall@%d' % k for k in self.metric_ks[:3]]
                description = 'Test: ' + ', '.join(s + ' {:.3f}'
                                                   for s in description_metrics)
                description = description.replace('NDCG',
                                                  'N').replace('Recall', 'R')
                description = description.format(
                    *(average_meter_set[k].avg for k in description_metrics))
                tqdm_dataloader.set_description(description)

            average_metrics = average_meter_set.averages()
            with open(
                    os.path.join(self.export_root, 'logs',
                                 'test_metrics.json'), 'w') as f:
                json.dump(average_metrics, f, indent=4)
            print(average_metrics)
Example #4
    def validate(self, epoch, accum_iter):
        self.model.eval()

        average_meter_set = AverageMeterSet()

        with torch.no_grad():
            tqdm_dataloader = tqdm(self.val_loader)
            for batch_idx, batch in enumerate(tqdm_dataloader):
                batch = [x.to(self.device) for x in batch]

                metrics = self.calculate_metrics(batch)

                for k, v in metrics.items():
                    average_meter_set.update(k, v)
                description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]] +\
                                      ['Recall@%d' % k for k in self.metric_ks[:3]]
                description = 'Val: ' + ', '.join(s + ' {:.3f}' for s in description_metrics)
                description = description.replace('NDCG', 'N').replace('Recall', 'R')
                description = description.format(*(average_meter_set[k].avg for k in description_metrics))
                tqdm_dataloader.set_description(description)

            log_data = {
                'state_dict': (self._create_state_dict()),
                'epoch': epoch,
                'accum_iter': accum_iter,
            }
            log_data.update(average_meter_set.averages())
            self.logger_service.log_val(log_data)
Example #5
    def test(self):
        self.model.eval()

        average_meter_set = AverageMeterSet()

        with torch.no_grad():
            tqdm_dataloader = tqdm(self.test_loader)
            for batch_idx, batch in enumerate(tqdm_dataloader):
                batch = [x.to(self.device) for x in batch]

                metrics = self.calculate_metrics(batch)

                for k, v in metrics.items():
                    average_meter_set.update(k, v)
                description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]] +\
                                      ['Recall@%d' % k for k in self.metric_ks[:3]]
                description = 'Test: ' + ', '.join(
                    s + ' {:.3f}' for s in description_metrics)
                description = description.replace('NDCG',
                                                  'N').replace('Recall', 'R')
                description = description.format(
                    *(average_meter_set[k].avg for k in description_metrics))
                tqdm_dataloader.set_description(description)
        return {
            'dataset': self.dataset_code,
            'pruning_code': self.prune_code,
            'pruning_perc': self.pruning_perc,
            'pruning_perc_embed': self.pruning_perc_embed,
            'pruning_perc_feed': self.pruning_perc_feed,
            'pruning_epochs': self.num_prune_epochs,
            'num_epochs': self.num_epochs,
            'result': description,
        }
Example #6
def validate_voc_file(eval_loader, model, thre, epoch, print_freq, results_dir):
    start_time = time.time()

    meters = AverageMeterSet()

    model.eval()
    Sig = torch.nn.Sigmoid()

    end = time.time()
    preds = []
    targets = []
    names = []
    for i, data in enumerate(eval_loader):
        assert len(data) >= 4
        input, target, name = data[0], data[1], data[3]
        meters.update('data_time', time.time() - end)
        # compute output
        with torch.no_grad():
            output = Sig(model(input.cuda()))

        # for mAP calculation
        preds.append(output.cpu())
        targets.append(target.cpu())
        names.extend(name)

        # measure elapsed time
        meters.update('batch_time', time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {meters[batch_time]:.3f}\t'
                  'Data {meters[data_time]:.3f}\t'
                  .format(i, len(eval_loader), meters=meters))

    preds = torch.cat(preds).numpy()
    targs = torch.cat(targets).numpy()
    if results_dir is not None:
        # save to results dir
        os.makedirs(results_dir, exist_ok=True)
        for i in range(20):  # the 20 Pascal VOC object classes
            cls_name = eval_loader.dataset.class_list[i]
            filename = '{}_{}.txt'.format(cls_name, eval_loader.dataset.image_set)
            with open(os.path.join(results_dir, filename), 'w') as f:
                for j in range(len(names)):
                    f.write('{} {}\n'.format(names[j], preds[j, i]))


    AP = eval_loader.dataset.eval_file(results_dir)
    eval_loader.dataset.show_AP(AP, print_func=LOG.info)

    mAP = 100 * AP.mean()
    print(" * TEST [{}] VOC2012 mAP: {}".format(epoch, mAP))

    print("--- testing epoch in {} seconds ---".format(time.time() - start_time))
    return mAP
Example #7
    def validate(self, epoch, accum_iter, mode, doLog=True, **kwargs):
        if mode == 'val':
            loader = self.val_loader
        elif mode == 'test':
            loader = self.test_loader
        else:
            raise ValueError("mode must be 'val' or 'test'")

        self.model.eval()

        average_meter_set = AverageMeterSet()
        num_instance = 0

        with torch.no_grad():
            tqdm_dataloader = tqdm(loader) if not self.pilot else loader
            for batch_idx, batch in enumerate(tqdm_dataloader):
                if self.pilot and batch_idx >= self.pilot_batch_cnt:
                    # print('Break validation due to pilot mode')
                    break
                batch = {k: v.to(self.device) for k, v in batch.items()}
                batch_size = next(iter(batch.values())).size(0)
                num_instance += batch_size

                metrics = self.calculate_metrics(batch)

                for k, v in metrics.items():
                    average_meter_set.update(k, v)
                if not self.pilot:
                    description_metrics = ['NDCG@%d' % k for k in self.metric_ks] +\
                           ['Recall@%d' % k for k in self.metric_ks]
                    description = '{}: '.format(mode.capitalize()) + ', '.join(
                        s + ' {:.4f}' for s in description_metrics)
                    description = description.replace('NDCG', 'N').replace(
                        'Recall', 'R')
                    description = description.format(
                        *(average_meter_set[k].avg
                          for k in description_metrics))
                    tqdm_dataloader.set_description(description)

            log_data = {
                'state_dict': (self._create_state_dict(epoch, accum_iter)),
                'epoch': epoch,
                'accum_iter': accum_iter,
                'num_eval_instance': num_instance,
            }
            log_data.update(average_meter_set.averages())
            log_data.update(kwargs)
            if doLog:
                if mode == 'val':
                    self.logger_service.log_val(log_data)
                elif mode == 'test':
                    self.logger_service.log_test(log_data)
                else:
                    raise ValueError("mode must be 'val' or 'test'")
        return log_data
Example #8
    def validate(self, epoch, accum_iter):
        self.model.eval()
        self.all_preds = []
        average_meter_set = AverageMeterSet()
        with torch.no_grad():
            tqdm_dataloader = tqdm(self.val_loader)
            for batch_idx, batch in enumerate(tqdm_dataloader):
                batch = [x.to(self.device) for x in batch]
                metrics, preds = self.calculate_metrics(batch)
                for p in preds:
                    self.all_preds.append(p.tolist())

                for k, v in metrics.items():
                    average_meter_set.update(k, v)
                description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]] + \
                                      ['Recall@%d' % k for k in self.metric_ks[:3]]
                description = 'Val: ' + ', '.join(s + ' {:.3f}'
                                                  for s in description_metrics)
                description = description.replace('NDCG',
                                                  'N').replace('Recall', 'R')
                description = description.format(
                    *(average_meter_set[k].avg for k in description_metrics))
                tqdm_dataloader.set_description(description)

            log_data = {
                'state_dict': (self._create_state_dict()),
                'epoch': epoch + 1,
                'accum_iter': accum_iter,
            }
            log_data.update(average_meter_set.averages())
            self.log_extra_val_info(log_data)
            self.logger_service.log_val(log_data)

        df = pd.DataFrame(self.all_preds,
                          columns=[
                              'prediction_' + str(i)
                              for i in range(len(self.all_preds[0]))
                          ])
        if not os.path.isdir(self.args.output_predictions_folder):
            os.makedirs(self.args.output_predictions_folder)

        with open(
                os.path.join(self.args.output_predictions_folder,
                             'config.json'), 'w') as f:
            self.args.recommender = "BERT4rec"
            self.args.seed = str(self.args.model_init_seed)
            args_dict = {}
            args_dict['args'] = vars(self.args)

            f.write(json.dumps(args_dict, indent=4, sort_keys=True))
        df.to_csv(self.args.output_predictions_folder + "/predictions.csv",
                  index=False)
Example #9
    def train_one_epoch(self, epoch, accum_iter):
        self.model.train()

        average_meter_set = AverageMeterSet()
        tqdm_dataloader = tqdm(self.train_loader)

        for batch_idx, batch in enumerate(tqdm_dataloader):

            batch = self.batch_to_device(batch)
            batch_size = self.args.train_batch_size

            # forward pass
            self.optimizer.zero_grad()
            loss = self.calculate_loss(batch)

            # backward pass
            loss.backward()
            self.optimizer.step()

            # update metrics
            average_meter_set.update('loss', loss.item())
            # defaults['lr'] never changes; read the live value from param_groups
            average_meter_set.update('lr', self.optimizer.param_groups[0]['lr'])

            tqdm_dataloader.set_description('Epoch {}, loss {:.3f} '.format(epoch + 1, average_meter_set['loss'].avg))
            accum_iter += batch_size

            if self._needs_to_log(accum_iter):
                tqdm_dataloader.set_description('Logging to Tensorboard')
                log_data = {
                    'state_dict': (self._create_state_dict()),
                    'epoch': epoch+1,
                    'accum_iter': accum_iter,
                }
                log_data.update(average_meter_set.averages())
                self.log_extra_train_info(log_data)
                self.logger_service.log_train(log_data)

            if self.args.local and batch_idx == 20:
                break

        # adapt learning rate
        if self.args.enable_lr_schedule:
            self.lr_scheduler.step()
            if epoch % self.lr_scheduler.step_size == 0:
                print(self.optimizer.param_groups[0]['lr'])

        return accum_iter
Example #10
def validate(eval_loader, model, epoch, print_freq, type_string=''):
    start_time = time.time()
    class_criterion = nn.CrossEntropyLoss().cuda()

    meters = AverageMeterSet()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, data in enumerate(eval_loader):
        input, target = data[0], data[1]
        meters.update('data_time', time.time() - end)

        with torch.no_grad():
            input = input.cuda()
            target = target.cuda()

            # compute output
            model_out = model(input)
            if isinstance(model_out, tuple):
                feat, class_logit = model_out
            else:
                class_logit = model_out

            class_loss = class_criterion(class_logit, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(class_logit.data, target.data, topk=(1, 5))
        minibatch_size = len(target)
        meters.update('class_loss', class_loss.item(), minibatch_size)
        meters.update('top1', prec1, minibatch_size)
        meters.update('top5', prec5, minibatch_size)

        # measure elapsed time
        meters.update('batch_time', time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            print(
                'Test: [{0}/{1}]\t'
                'Time {meters[batch_time]:.3f}\t'
                'Data {meters[data_time]:.3f}\t'
                'Class {meters[class_loss]:.4f}\t'
                'Prec@1 {meters[top1]:.3f}\t'
                'Prec@5 {meters[top5]:.3f}'
                .format(i, len(eval_loader), meters=meters))

    print(' * Prec@1 {top1.avg:.3f}\tPrec@5 {top5.avg:.3f}'
          .format(top1=meters['top1'], top5=meters['top5']))
    print("--- testing epoch in {} seconds ---".format(time.time() - start_time))
    return meters['top1'].avg
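
The accuracy helper used above is not defined on this page. The standard top-k precision recipe from the PyTorch ImageNet example fits the call site (topk=(1, 5), percentage values fed to the meters); the project's own version may differ in detail:

def accuracy(output, target, topk=(1,)):
    # Standard top-k precision recipe (an assumption here, not this
    # project's verified implementation). Returns one percentage per k.
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size).item())
    return res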
Example #11
def validate_voc(eval_loader, model, thre, epoch, print_freq):
    start_time = time.time()

    meters = AverageMeterSet()

    model.eval()
    Sig = torch.nn.Sigmoid()

    end = time.time()
    preds = []
    targets = []
    for i, data in enumerate(eval_loader):
        input, target = data[0], data[1]
        meters.update('data_time', time.time() - end)
        # compute output
        with torch.no_grad():
            output = Sig(model(input.cuda())).cpu()

        # for mAP calculation
        preds.append(output)
        targets.append(target.cpu())

        # measure accuracy and record loss
        this_prec, this_rec = prec_recall_for_batch(output.data, target, thre)
        meters.update('prec', float(this_prec), input.size(0))
        meters.update('rec', float(this_rec), input.size(0))

        # measure elapsed time
        meters.update('batch_time', time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {meters[batch_time]:.3f}\t'
                  'Data {meters[data_time]:.3f}\t'
                  'Prec {meters[prec]:.2f}\t'
                  'Recall {meters[rec]:.2f}'
                  .format(i, len(eval_loader), meters=meters))

    targs = torch.cat(targets).numpy()
    preds = torch.cat(preds).numpy()
    AP = eval_loader.dataset.eval(preds, targs)
    eval_loader.dataset.show_AP(AP)

    mAP = 100 * AP.mean()
    print(" * TEST [{}] mAP: {}".format(epoch, mAP))

    print("--- testing epoch in {} seconds ---".format(time.time() - start_time))
    return mAP
Example #12
    def validate(self, epoch, accum_iter):
        self.model.eval()

        average_meter_set = AverageMeterSet()

        with torch.no_grad():
            tqdm_dataloader = tqdm(self.val_loader)
            for batch_idx, batch in enumerate(tqdm_dataloader):
                batch = [x.to(self.device) for x in batch]

                metrics = self.calculate_metrics(batch)

                for k, v in metrics.items():
                    average_meter_set.update(k, v)
                description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]] + \
                                      ['Recall@%d' % k for k in self.metric_ks[:3]]
                if 'accuracy' in self.args.metrics_to_log:
                    description_metrics = ['accuracy']
                description = 'Val: ' + ', '.join(s + ' {:.3f}' for s in description_metrics)
                description = description.replace('NDCG', 'N').replace('Recall', 'R')
                description = description.format(*(average_meter_set[k].avg for k in description_metrics))
                tqdm_dataloader.set_description(description)

            log_data = {
                'state_dict': (self._create_state_dict()),
                'epoch': epoch + 1,
                'accum_iter': accum_iter,
                'user_embedding': self.model.embedding.user.weight.cpu().detach().numpy()
                if self.args.dump_useritem_embeddings == 'True'
                   and self.model.embedding.user is not None
                else None,
                'item_embedding': self.model.embedding.token.weight.cpu().detach().numpy()
                if self.args.dump_useritem_embeddings == 'True'
                else None,
            }
            log_data.update(average_meter_set.averages())
            self.log_extra_val_info(log_data)
            self.logger_service.log_val(log_data)
Example #13
    def train_one_epoch(self, epoch, accum_iter, train_loader, **kwargs):
        self.model.train()

        average_meter_set = AverageMeterSet()
        num_instance = 0
        tqdm_dataloader = tqdm(
            train_loader) if not self.pilot else train_loader

        for batch_idx, batch in enumerate(tqdm_dataloader):
            if self.pilot and batch_idx >= self.pilot_batch_cnt:
                # print('Break training due to pilot mode')
                break
            batch_size = next(iter(batch.values())).size(0)
            batch = {k: v.to(self.device) for k, v in batch.items()}
            num_instance += batch_size

            if self.total_anneal_steps > 0:
                anneal = min(self.anneal_cap,
                             1. * self.update_count / self.total_anneal_steps)
            else:
                anneal = self.anneal_cap

            self.optimizer.zero_grad()
            loss = self.calculate_loss(batch, anneal)
            if isinstance(loss, tuple):
                loss, extra_info = loss
                for k, v in extra_info.items():
                    average_meter_set.update(k, v)
            loss.backward()

            if self.clip_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.clip_grad_norm)

            self.optimizer.step()

            self.update_count += 1

            average_meter_set.update('loss', loss.item())
            if not self.pilot:
                tqdm_dataloader.set_description(
                    'Epoch {}, loss {:.3f} '.format(
                        epoch, average_meter_set['loss'].avg))

            accum_iter += batch_size

            if self._needs_to_log(accum_iter):
                if not self.pilot:
                    tqdm_dataloader.set_description('Logging')
                log_data = {
                    # 'state_dict': (self._create_state_dict()),
                    'epoch': epoch,
                    'accum_iter': accum_iter,
                }
                log_data.update(average_meter_set.averages())
                log_data.update(kwargs)
                self.log_extra_train_info(log_data)
                self.logger_service.log_train(log_data)

        log_data = {
            # 'state_dict': (self._create_state_dict()),
            'epoch': epoch,
            'accum_iter': accum_iter,
            'num_train_instance': num_instance,
        }
        log_data.update(average_meter_set.averages())
        log_data.update(kwargs)
        self.log_extra_train_info(log_data)
        self.logger_service.log_train(log_data)
        return accum_iter
Example #14
def train(train_loader, model, criterion, optimizer, epoch, args):
    start_time = time.time()
    meters = AverageMeterSet()

    Sig = torch.nn.Sigmoid()

    """
        Switch to eval mode:
        Under the protocol of linear classification on frozen features/models,
        it is not legitimate to change any part of the pre-trained model.
        BatchNorm in train mode may revise running mean/std (even if it receives
        no gradient), which are part of the model parameters too.
    """
    # Both branches of the original `if args.linear_eval` called model.eval()
    # (model.train() was commented out), so the check is collapsed here.
    model.eval()

    end = time.time()
    for i, data in enumerate(train_loader):
        images, target = data[0], data[1]
        meters.update('data_time', time.time() - end)

        adjust_learning_rate(optimizer, epoch, i, len(train_loader), args)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        if torch.cuda.is_available():
            target = target.float().cuda(args.gpu, non_blocking=True)

        # compute output
        model_output = model(images)
        if isinstance(model_output, tuple):
            feat, class_logit = model_output
        else:
            class_logit = model_output

        output = Sig(class_logit)

        loss = criterion(class_logit, target)
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        meters.update('lr', optimizer.param_groups[0]['lr'])
        meters.update('class_loss', loss.item())
        # measure accuracy and record loss
        this_prec, this_rec = prec_recall_for_batch(output.data, target,
                                                    args.thre)
        meters.update('prec', float(this_prec), images.size(0))
        meters.update('rec', float(this_rec), images.size(0))

        # measure elapsed time
        meters.update('batch_time', time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {meters[batch_time]:.3f}\t'
                  'Data {meters[data_time]:.3f}\t'
                  'Class {meters[class_loss]:.4f}\t'
                  'Prec {meters[prec]:.3f}\t'
                  'Rec {meters[rec]:.3f}\t'.format(epoch,
                                                   i,
                                                   len(train_loader),
                                                   meters=meters))

    print(' * TRAIN Prec {:.3f} ({:.1f}/{:.1f}) Recall {:.3f} ({:.1f}/{:.1f})'.
          format(meters['prec'].avg, meters['prec'].sum / 100,
                 meters['prec'].count, meters['rec'].avg,
                 meters['rec'].sum / 100, meters['rec'].count))
    print("--- training epoch in {} seconds ---".format(time.time() -
                                                        start_time))
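
Examples #14 and #15 both call an adjust_learning_rate helper once per batch (with and without an args argument) that is not shown on this page. As a stand-in only, here is a hypothetical per-batch cosine schedule matching that signature; the real schedules are project-specific:

import math

def adjust_learning_rate(optimizer, epoch, step, steps_per_epoch, args=None):
    # Hypothetical sketch: decay the learning rate with a cosine curve over
    # the whole run, recomputed at every batch. base_lr and num_epochs are
    # assumed defaults when no args namespace is passed (Example #15).
    base_lr = getattr(args, 'lr', 0.1) if args is not None else 0.1
    num_epochs = getattr(args, 'epochs', 90) if args is not None else 90
    progress = (epoch + step / steps_per_epoch) / num_epochs
    lr = 0.5 * base_lr * (1.0 + math.cos(math.pi * min(progress, 1.0)))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr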
Example #15
def train(train_loader, model, class_criterion, optimizer, epoch):
    global global_step
    start_time = time.time()

    Sig = torch.nn.Sigmoid()

    meters = AverageMeterSet()

    # switch to train mode
    model.train()

    end = time.time()
    for i, data in enumerate(train_loader):
        input, target = data[0], data[1]
        # measure data loading time
        meters.update('data_time', time.time() - end)

        adjust_learning_rate(optimizer, epoch, i, len(train_loader))

        input, target = input.cuda(), target.float().cuda()

        model_out = model(input)
        if isinstance(model_out, tuple):
            feat, class_logit = model_out
        else:
            class_logit = model_out

        output = Sig(class_logit)
        # output = class_logit
        class_loss = class_criterion(class_logit, target)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        class_loss.backward()
        # nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
        optimizer.step()
        global_step += 1

        meters.update('lr', optimizer.param_groups[0]['lr'])
        minibatch_size = len(target)
        meters.update('class_loss', class_loss.item(), minibatch_size)
        # measure accuracy and record loss (`thre`, like `print_freq` below,
        # comes from module scope in the original source)
        this_prec, this_rec = prec_recall_for_batch(output.data, target, thre)
        meters.update('prec', float(this_prec), input.size(0))
        meters.update('rec', float(this_rec), input.size(0))

        # measure elapsed time
        meters.update('batch_time', time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {meters[batch_time]:.3f}\t'
                  'Data {meters[data_time]:.3f}\t'
                  'Class {meters[class_loss]:.4f}\t'
                  'Prec {meters[prec]:.3f}\t'
                  'Rec {meters[rec]:.3f}\t'.format(epoch,
                                                   i,
                                                   len(train_loader),
                                                   meters=meters))

    print(' * TRAIN Prec {:.3f} ({:.1f}/{:.1f}) Recall {:.3f} ({:.1f}/{:.1f})'.
          format(meters['prec'].avg, meters['prec'].sum / 100,
                 meters['prec'].count, meters['rec'].avg,
                 meters['rec'].sum / 100, meters['rec'].count))

    print("--- training epoch in {} seconds ---".format(time.time() -
                                                        start_time))
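
Finally, several examples score multi-label predictions with prec_recall_for_batch, which is also undefined on this page. A hypothetical reconstruction consistent with its call sites (sigmoid outputs binarized at thre, percentage return values, since the summary lines divide meters['prec'].sum by 100):

def prec_recall_for_batch(output, target, thre):
    # Hypothetical reconstruction: binarize the sigmoid outputs at `thre`
    # and compare against the 0/1 target matrix; return micro-averaged
    # precision and recall as percentages.
    pred = (output > thre).float()
    target = target.float()
    true_pos = (pred * target).sum()
    precision = true_pos / pred.sum().clamp(min=1)
    recall = true_pos / target.sum().clamp(min=1)
    return 100.0 * precision.item(), 100.0 * recall.item()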