コード例 #1
0
    def test_log_summary(self, mock_board, _):
        """on_start with log_trial_summary=True writes the trial text under the 'trial' tag.

        ``mock_board`` and ``_`` are injected by patch decorators that are not
        visible in this chunk (presumably the summary-writer factory — confirm).
        """
        writer = Mock()
        writer.add_text = Mock()
        mock_board.return_value = writer
        trial_repr = 'test'

        state = {
            torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3)),
            torchbearer.EPOCH: 0,
            torchbearer.METRICS: {'test': 1},
            torchbearer.BATCH: 0,
            torchbearer.SELF: trial_repr,
        }
        callback = TensorBoardText(
            write_batch_metrics=False,
            write_epoch_metrics=False,
            log_trial_summary=True,
        )
        callback.on_start(state)

        # call_args is an (args, kwargs) pair describing the most recent add_text call.
        positional, _kwargs = writer.add_text.call_args
        self.assertEqual(positional[0], 'trial')
        self.assertEqual(positional[1], str(trial_repr))
コード例 #2
0
    def test_batch_metrics_visdom(self, mock_visdom, mock_writer, _):
        """In visdom mode, batch text carries an epoch/batch header and is written at step 1.

        ``mock_visdom``, ``mock_writer`` and ``_`` are injected by patch decorators
        not visible in this chunk — confirm what each one patches.
        """
        writer = Mock()
        writer.add_text = Mock()
        mock_writer.return_value = writer

        metrics = {'test': 1}
        state = {
            torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3)),
            torchbearer.EPOCH: 0,
            torchbearer.METRICS: metrics,
            torchbearer.BATCH: 0,
            torchbearer.TRAIN_STEPS: 0,
        }

        callback = TensorBoardText(
            visdom=True,
            write_batch_metrics=True,
            write_epoch_metrics=False,
            log_trial_summary=False,
        )
        callback.on_start(state)
        callback.on_start_epoch(state)
        callback.on_step_training(state)

        header = '<h3>Epoch {} - Batch {}</h3>'.format(
            state[torchbearer.EPOCH], state[torchbearer.BATCH])
        expected_text = header + TensorBoardText.table_formatter(
            str(state[torchbearer.METRICS]))
        writer.add_text.assert_called_once_with('batch', expected_text, 1)
        writer.add_text.reset_mock()

        # The remaining hooks are exercised for completeness; nothing is asserted on them.
        callback.on_step_validation(state)
        callback.on_end(state)
コード例 #3
0
    def test_batch_metrics(self, mock_board, _):
        """Without visdom, batch logging writes the bare metric table at step 0.

        ``mock_board`` and ``_`` are injected by patch decorators not visible in
        this chunk (presumably the summary-writer factory — confirm).
        """
        writer = Mock()
        writer.add_text = Mock()
        mock_board.return_value = writer

        metrics = {'test': 1}
        state = {
            torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3)),
            torchbearer.EPOCH: 0,
            torchbearer.METRICS: metrics,
            torchbearer.BATCH: 0,
        }

        callback = TensorBoardText(
            write_batch_metrics=True,
            write_epoch_metrics=False,
            log_trial_summary=False,
        )
        callback.on_start(state)
        callback.on_start_epoch(state)
        callback.on_step_training(state)

        expected_text = TensorBoardText.table_formatter(
            str(state[torchbearer.METRICS]))
        writer.add_text.assert_called_once_with('batch', expected_text, 0)
        writer.add_text.reset_mock()

        # The remaining hooks are exercised for completeness; nothing is asserted on them.
        callback.on_end_epoch(state)
        callback.on_end(state)
コード例 #4
0
    def test_batch_writer_visdom(self, mock_visdom, mock_writer, _):
        """Visdom batch text is '<h3>Epoch E - Batch B</h3>' + metric table, written at step 1.

        ``mock_visdom``, ``mock_writer`` and ``_`` are injected by patch decorators
        not visible in this chunk — confirm what each one patches.
        """
        callback = TensorBoardText(
            visdom=True,
            write_epoch_metrics=False,
            write_batch_metrics=True,
            log_trial_summary=False,
        )

        metrics = {'test_metric_1': 1, 'test_metric_2': 1}
        epoch, batch = 1, 100
        state = {
            torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3)),
            torchbearer.EPOCH: epoch,
            torchbearer.BATCH: batch,
            torchbearer.METRICS: metrics,
        }
        table = TensorBoardText.table_formatter(str(metrics))
        expected_text = '<h3>Epoch {} - Batch {}</h3>'.format(epoch, batch) + table

        for hook in (callback.on_start,
                     callback.on_start_training,
                     callback.on_start_epoch,
                     callback.on_step_training):
            hook(state)

        mock_writer.return_value.add_text.assert_called_once_with(
            'batch', expected_text, 1)

        callback.on_end_epoch(state)
        callback.on_end(state)
コード例 #5
0
    def test_epoch_writer(self, mock_writer, _):
        """With default settings, on_end_epoch writes the metric table under 'epoch' at the epoch index.

        ``mock_writer`` and ``_`` are injected by patch decorators not visible in
        this chunk (presumably the summary-writer factory — confirm).
        """
        callback = TensorBoardText(log_trial_summary=False)

        metrics = {'test_metric_1': 1, 'test_metric_2': 1}
        epoch = 1
        state = {
            torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3)),
            torchbearer.EPOCH: epoch,
            torchbearer.METRICS: metrics,
        }
        expected_text = TensorBoardText.table_formatter(str(metrics))

        for hook in (callback.on_start,
                     callback.on_start_training,
                     callback.on_start_epoch,
                     callback.on_end_epoch):
            hook(state)

        mock_writer.return_value.add_text.assert_called_once_with(
            'epoch', expected_text, epoch)
        callback.on_end(state)
コード例 #6
0
ファイル: trainer.py プロジェクト: Kirito-520/AI-competition
# Wrap the evaluation sets in DataLoaders. NLP datasets (those in nlp_data) are
# passed through unchanged, presumably because they are already iterable
# loaders built elsewhere — confirm. A missing valset stays None.
# NOTE(review): shuffling is unnecessary for evaluation; kept as-is because
# it also affects global RNG consumption and thus run reproducibility.
if (valset is not None) and (args.dataset not in nlp_data):
    valloader = torch.utils.data.DataLoader(
        valset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
else:
    valloader = valset

if args.dataset not in nlp_data:
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
else:
    testloader = testset

print('==> Building model..')
net = get_model(args, classes, nc)
net = nn.DataParallel(net) if args.parallel else net

# Datasets in nlp_data, and any dataset whose name contains 'modelnet', train
# with Adam; everything else uses SGD with momentum and weight decay. Choosing
# up front avoids constructing an SGD optimizer only to discard it.
if (args.dataset in nlp_data) or ('modelnet' in args.dataset):
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
else:
    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)


print('==> Setting up callbacks..')
# Unique per-run suffix (timestamp + run id) shared by both TensorBoard callbacks.
current_time = datetime.now().strftime('%b%d_%H-%M-%S') + "-run-" + str(args.run_id)
tboard = TensorBoard(write_graph=False, comment=current_time, log_dir=args.log_dir)
tboardtext = TensorBoardText(write_epoch_metrics=False, comment=current_time, log_dir=args.log_dir)


@torchbearer.callbacks.on_start
def write_params(_):
    """Log the run's arguments to TensorBoard as an HTML table under the 'params' tag.

    Registered as an ``on_start`` callback; the callback state argument is unused.
    """
    # vars() returns the object's own __dict__, so copy it first: writing into
    # the original would permanently replace args.schedule with its string form
    # for the rest of the run.
    params = dict(vars(args))
    # DataFrame(..., index=[0]) expects one value per column; the schedule is
    # presumably a list, so it is stringified to fit a single cell.
    params['schedule'] = str(params['schedule'])
    df = pd.DataFrame(params, index=[0]).transpose()
    tboardtext.get_writer(tboardtext.log_dir).add_text('params', df.to_html(), 1)


modes = {
    'fmix': FMix(decay_power=args.f_decay, alpha=args.alpha, size=size, max_soft=0, reformulate=args.reformulate),
    'mixup': RMixup(args.alpha, reformulate=args.reformulate),
    'cutmix': CutMix(args.alpha, classes, True),
    'pointcloud_fmix': PointNetFMix(args.pointcloud_resolution, decay_power=args.f_decay, alpha=args.alpha, max_soft=0,