Example #1
0
def torch_summarize(model, show_weights=True, show_parameters=True):
    """Summarize a torch model by listing its trainable parameters and weights.

    Walks ``model._modules`` one level deep, recursing into container modules
    so nested layers are expanded as well.

    Args:
        model: a ``torch.nn.Module`` instance.
        show_weights: if True, append each submodule's parameter shapes.
        show_parameters: if True, append each submodule's parameter count.

    Returns:
        A human-readable, indented multi-line string describing the model.
    """
    tmpstr = model.__class__.__name__ + ' (\n'
    for key, module in model._modules.items():
        # Recurse into containers so their children are expanded too.
        # isinstance (rather than an exact-type check) also covers subclasses.
        # BUGFIX: propagate the show_* flags into the recursion — previously
        # nested containers silently ignored the caller's options.
        if isinstance(module, (torch.nn.modules.container.Container,
                               torch.nn.modules.container.Sequential)):
            modstr = torch_summarize(module, show_weights, show_parameters)
        else:
            modstr = repr(module)
        modstr = _addindent(modstr, 2)

        # numel() is the direct API for the per-tensor element count
        # (equivalent to the product of the size tuple).
        params = sum(p.numel() for p in module.parameters())
        weights = tuple(tuple(p.size()) for p in module.parameters())

        tmpstr += '  (' + key + '): ' + modstr
        if show_weights:
            tmpstr += ', weights={}'.format(weights)
        if show_parameters:
            tmpstr += ', parameters={}'.format(params)
        tmpstr += '\n'

    tmpstr = tmpstr + ')'
    return tmpstr
Example #2
0
        # --- Model / data hyperparameters ---------------------------------
        batch_size = params['batch_size']
        n_classes = 10
        # NOTE(review): input_channels is assigned 3 * 32 and then immediately
        # overwritten with 1 — the first assignment is dead code. Confirm which
        # value is actually intended before removing either line.
        input_channels = 3 * 32
        input_channels = 1
        # Flattened sequence length; presumably frames * 3 coordinates * 32
        # joints for skeleton data — TODO confirm against the dataset loader.
        seq_length = data['Frame_len'] * 3 * 32
        # NOTE(review): total_step is computed but not used within this view.
        total_step = len(training_set)

        learning_rate = opt.force_learning_rate

        # Temporal Convolutional Network: 8 levels of 25 channels each.
        #model = simpleLSTM(input_size, hidden_size, num_layers, output_size)
        model = TCN(input_channels,
                    n_classes, [25] * 8,
                    kernel_size=7,
                    dropout=0.00)
        # Print a layer-by-layer summary (weight shapes + parameter counts).
        print(torch_summarize(model))

        model.to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

        # NOTE(review): sum_done is set here but not read within this view.
        sum_done = False
        for epoch in range(max_epochs):
            # Training
            model.train()
            for i, (skeleton, label) in enumerate(training_generator):

                # Reshape to (batch, channels, sequence) for the TCN's 1-D
                # convolutions, then move the batch to the target device.
                skeleton = skeleton.reshape(-1, input_channels,
                                            seq_length).to(device)
                label = label.to(device)