Exemplo n.º 1
0
    def log_validation_results(engine):
        """Evaluate the train and val splits and log their results.

        For each split this runs the matching evaluator and reads its
        ``"multitask"`` metric. It writes every entry of that metric's
        ``'metrics'`` and ``'summaries'`` dicts to ``writer`` as a scalar
        tagged with the current epoch. It then logs a header line and a
        summary table. For the ``'train'`` split only, it also steps
        ``scheduler`` with the last AP value.

        Uses free names resolved from an enclosing or global scope:
        ``train_loader``, ``val_loader``, ``train_evaluator``,
        ``val_evaluator``, ``writer``, ``logger``, ``print_summar_table``,
        ``attr_name`` and ``scheduler``.

        Args:
            engine: engine whose ``state.epoch`` labels the logged values.
        """
        splits = (
            ('train', train_loader, train_evaluator),
            ('val', val_loader, val_evaluator),
        )
        # (key inside the "multitask" result, suffix used in the scalar tag)
        scalar_groups = (('metrics', '_metrics'), ('summaries', '_summary'))

        for split_name, loader, evaluator in splits:
            evaluator.run(loader)
            results = evaluator.state.metrics["multitask"]
            epoch = engine.state.epoch

            for group_key, tag_suffix in scalar_groups:
                for metric_name, value in results[group_key].items():
                    writer.add_scalar(
                        f"{split_name}{tag_suffix}/{metric_name}", value, epoch)

            logger(f"{split_name}: Validation Results - Epoch: {epoch}")
            print_summar_table(logger, attr_name, results['logger'])

            # Update Learning Rate.
            # NOTE(review): the scheduler is driven by the *training* AP,
            # not the validation AP; confirm this is intended.
            if split_name == 'train':
                scheduler.step(results['logger']['attr']['ap'][-1])
    def save(self, epoch, metrics_info):
        """Checkpoint the model when the latest AP beats the best seen so far.

        Reads the most recent AP from ``metrics_info['logger']['attr']['ap']``.
        If it is strictly greater than ``self.max_ap``, this method:

        * records it as the new best;
        * saves ``s_net.module.state_dict()`` to ``<self.save_root>/ap<ap>``;
        * writes a header line (epoch and current learning rate), the summary
          table and the AP to ``logger_file``.

        Otherwise nothing happens.

        Args:
            epoch: epoch number; used only in the log message.
            metrics_info: evaluator output; ``metrics_info['logger']`` is read.
        """
        latest_ap = metrics_info['logger']['attr']['ap'][-1]
        if not latest_ap > self.max_ap:
            return  # no improvement: keep the previous best checkpoint
        self.max_ap = latest_ap

        checkpoint_path = os.path.join(self.save_root, f'ap{latest_ap}')
        torch.save(s_net.module.state_dict(), checkpoint_path)

        current_lr = optimizer.optimizer.param_groups[0]['lr']
        logger_file(f"val: Validation Results - Epoch: {epoch} - LR: {current_lr}")
        print_summar_table(logger_file, attr_name, metrics_info['logger'])
        logger_file('AP:%0.3f' % latest_ap)
def test(net, epoch):
    """Evaluate on the train and test loaders, then step the optimizer.

    Switches ``net`` to eval mode. It then runs ``train_evaluator`` on
    ``trainloader`` and ``val_evaluator`` on ``testloader``, logging a
    header line and a summary table for each. Finally it calls
    ``optimizer.step``: with the last validation AP when
    ``args.scheduler == 'pleau'``, and with no argument otherwise.

    Args:
        net: the model. ``net.eval()`` is called and it is not switched
            back to train mode here.
        epoch: epoch number; used only in the log messages.

    Returns:
        The ``"multitask"`` metrics of the last evaluated split (``'val'``).
    """
    net.eval()
    for split_name, loader, evaluator in (
        ('train', trainloader, train_evaluator),
        ('val', testloader, val_evaluator),
    ):
        evaluator.run(loader)
        metrics_info = evaluator.state.metrics["multitask"]
        logger(f"{split_name}: Validation Results - Epoch: {epoch}")
        print_summar_table(logger, attr_name, metrics_info['logger'])

    # NOTE(review): 'pleau' looks like a misspelling of 'plateau'. It must
    # match the value supplied on the command line, so confirm before changing.
    # Presumably ``optimizer`` wraps an LR scheduler whose ``step`` accepts a
    # metric; plain torch ``Optimizer.step`` expects a closure. Confirm.
    if args.scheduler != 'pleau':
        optimizer.step()
    else:
        optimizer.step(metrics_info['logger']['attr']['ap'][-1])
    return metrics_info