def init_model(save_path, args):
    """Load a saved network from ``save_path`` or create and save a new one.

    If ``model.json`` is present in ``save_path`` and ``args.load_model`` is
    truthy, the stored model description is read and passed to
    ``networks.load_model``. Before loading, the stored ``unit_type``,
    ``input_size``, ``hidden_size`` and ``output_size`` are compared with the
    corresponding attributes of ``args``. A mismatch only prints a warning;
    the stored model is still loaded.

    Otherwise a new ``networks.SimpleRNN`` is built from ``args``, its
    ``save_state`` flag is cleared, and it is saved to ``save_path`` under the
    name ``'model'``.

    Args:
        save_path: Directory that is searched for, and written with, the
            model file.
        args: Config namespace. Reads ``load_model``, ``unit_type``,
            ``input_size``, ``hidden_size``, ``output_size`` and ``skip_con``.

    Returns:
        The loaded or newly created network.

    Raises:
        KeyError: If a loaded model file lacks ``'model_data'`` or one of the
            compared fields (checked in order, stopping at the first mismatch).
    """
    # Search for an existing model in the save directory
    if miscfuncs.file_check('model.json', save_path) and args.load_model:
        print('existing model file found, loading network')
        model_data = miscfuncs.json_load('model', save_path)
        # Check that model.json describes the same architecture as the config.
        # Use an explicit comparison rather than ``assert``: assert statements
        # are removed under ``python -O``, which would silently skip this check.
        stored = model_data['model_data']
        structure_keys = ('unit_type', 'input_size', 'hidden_size', 'output_size')
        if any(stored[key] != getattr(args, key) for key in structure_keys):
            print(
                "model file found with network structure not matching config file structure"
            )
        network = networks.load_model(model_data)
    # If no existing model is found, create a new one
    else:
        print('no saved model found, creating new network')
        network = networks.SimpleRNN(input_size=args.input_size,
                                     unit_type=args.unit_type,
                                     hidden_size=args.hidden_size,
                                     output_size=args.output_size,
                                     skip=args.skip_con)
        network.save_state = False
        network.save_model('model', save_path)
    return network
Ejemplo n.º 2
0
    def test_forward(self):
        """Run a forward pass through a range of SimpleRNN configurations.

        Each entry below is one constructor configuration. The networks are
        built and handed to ``run_net`` one after another, in table order.
        """
        configurations = [
            dict(input_size=1, output_size=1, unit_type="LSTM",
                 hidden_size=32, skip=1, bias_fl=True, num_layers=1),
            dict(input_size=2, output_size=4, unit_type="GRU",
                 hidden_size=16, skip=0, bias_fl=False, num_layers=1),
            dict(input_size=1, output_size=1, unit_type="GRU",
                 hidden_size=16, skip=0, bias_fl=False, num_layers=1),
            # Degenerate single-unit hidden layer.
            dict(input_size=1, output_size=1, unit_type="GRU",
                 hidden_size=1, skip=0, bias_fl=False, num_layers=1),
            # Vanilla RNN cell with two stacked layers.
            dict(input_size=1, output_size=2, unit_type="RNN",
                 hidden_size=16, skip=1, bias_fl=True, num_layers=2),
        ]
        for config in configurations:
            network = networks.SimpleRNN(**config)
            run_net(network)
Ejemplo n.º 3
0
    def test_save_model(self):
        """Pass several SimpleRNN configurations through ``model_save_test``.

        Configurations that omit ``hidden_size``, ``num_layers`` or
        ``bias_fl`` rely on whatever defaults ``SimpleRNN`` defines.
        """
        cases = (
            {'input_size': 1, 'output_size': 1, 'unit_type': 'LSTM',
             'hidden_size': 8, 'skip': 1},
            {'input_size': 1, 'output_size': 1, 'unit_type': 'LSTM',
             'hidden_size': 8, 'skip': 0},
            {'input_size': 1, 'output_size': 1, 'unit_type': 'GRU',
             'num_layers': 2, 'skip': 1},
            {'input_size': 1, 'output_size': 1, 'unit_type': 'GRU',
             'num_layers': 2, 'skip': 1, 'bias_fl': False},
            {'input_size': 5, 'output_size': 2, 'unit_type': 'GRU',
             'num_layers': 2, 'skip': 2},
        )
        for case in cases:
            network = networks.SimpleRNN(**case)
            model_save_test(network)
            # Release the network before constructing the next configuration.
            del network
Ejemplo n.º 4
0
    # Build the dataset from args.data_location with train/val/test subsets.
    # NOTE(review): this rebinds the name ``dataset`` from the module (used on
    # the right-hand side) to a DataSet instance. If this code runs inside a
    # function, ``dataset`` becomes a local name and ``dataset.DataSet`` raises
    # UnboundLocalError; at module level it works, but the module is no longer
    # reachable by that name afterwards. Consider renaming the instance —
    # TODO confirm the enclosing scope (its header is not visible here).
    dataset = dataset.DataSet(data_dir=args.data_location)
    # Only the training subset is given a frame length (args.seg_len).
    dataset.create_subset('train', frame_len=args.seg_len)
    dataset.create_subset('val')
    dataset.create_subset('test')

    # Alternative loading path (disabled): load a separate file per subset.
    #dataset = dataset.DataSet('data')
    #dataset.load_file('train/BehPhaserToneoffSingles1', 'train')
    #dataset.load_file('val/BehPhaserToneoffSingles1', 'val')
    #dataset.load_file('test/BehPhaserToneoffSingles1', 'test')

    # Load a single file (args.file_name) into all three subsets; presumably the
    # list gives the share assigned to each subset (75% / 12.5% / 12.5%) —
    # confirm against DataSet.load_file.
    dataset.load_file(args.file_name, ['train', 'val', 'test'],
                      [0.75, 0.125, 0.125])

    # Create the SimpleRNN network from the parsed arguments.
    # NOTE(review): output_size and skip are not passed, so whatever defaults
    # SimpleRNN defines apply; this also reads args.in_size (not
    # args.input_size) — confirm the argument name is intended.
    network = networks.SimpleRNN(hidden_size=args.hidden_size,
                                 num_layers=args.num_layers,
                                 unit_type=args.unit_type,
                                 input_size=args.in_size)

    # When args.cur_epoch is falsy (presumably a fresh run rather than a resume),
    # ensure save_path is an empty directory: an existing one is deleted with all
    # its contents and recreated. save_path is defined outside this view.
    if not args.cur_epoch:
        try:
            os.mkdir(save_path)
        except FileExistsError:
            shutil.rmtree(save_path)
            os.mkdir(save_path)

    # Move the network to the GPU when the ``cuda`` flag (defined outside this
    # view) is truthy.
    if cuda:
        network = network.cuda()

    # Adam over all network parameters with learning rate args.learn_rate.
    optimiser = torch.optim.Adam(network.parameters(), lr=args.learn_rate)
    # Loss built from args.loss_fcns and args.pre_filt; semantics live in
    # training.LossWrapper.
    loss_functions = training.LossWrapper(args.loss_fcns, args.pre_filt)