Example #1
0
def main(args):
    """Build the model, optimizer and data loader described by ``args``, then train or evaluate.

    Args:
        args: parsed command-line namespace. Attributes read here and in the
            ``_build_*`` helpers: ``model``, ``input_dim``, ``lstm_hidden_dim``,
            ``time_step``, ``optim``, ``lr``, ``weight_decay``, ``load_path``,
            ``recover``, ``dataset``, ``batch_size``, ``num_workers``,
            ``evaluate``.

    Raises:
        ValueError: if ``args.model``, ``args.optim`` or ``args.dataset`` is not
            one of the supported choices (the message names the bad value).
    """
    model = _build_model(args)
    model.cuda()

    optimizer = _build_optimizer(model.parameters(), args)

    # No learning-rate schedule is configured; ``train`` receives None.
    lr_scheduler = None
    # Weights are restored only when a path is given AND recovery is requested.
    if args.load_path and args.recover:
        load_model(model, args.load_path, strict=True)
        print('load model state dict in {}'.format(args.load_path))

    train_set = _build_dataset(args)
    train_dataloader = DataLoader(train_set, batch_size=args.batch_size,
                                  shuffle=True, num_workers=args.num_workers,
                                  pin_memory=True)

    # NOTE(review): there is no held-out split here -- evaluation and the
    # validation loader handed to ``train`` both reuse the training data.
    if args.evaluate:
        validate(train_dataloader, model)
        return

    train(train_dataloader, train_dataloader, model, optimizer, lr_scheduler, args)


def _build_model(args):
    """Return the network named by ``args.model``; only 'LSTM' is supported."""
    if args.model != 'LSTM':
        raise ValueError(
            "unsupported model {!r}; expected 'LSTM'".format(args.model))
    return LSTM(input_dim=args.input_dim, lstm_hidden_dim=args.lstm_hidden_dim,
                time_step=args.time_step)


def _build_optimizer(params, args):
    """Return the torch optimizer named by ``args.optim`` over ``params``.

    Learning rate and weight decay come from ``args``; the remaining
    hyper-parameters (SGD momentum 0.9, RMSprop alpha 0.999, Adam betas
    (0.9, 0.999), eps 1e-8) are fixed here.
    """
    if args.optim == 'SGD':
        return torch.optim.SGD(params, lr=args.lr, weight_decay=args.weight_decay)
    if args.optim == 'SGD_momentum':
        return torch.optim.SGD(params, lr=args.lr, momentum=0.9,
                               weight_decay=args.weight_decay)
    if args.optim == 'Adagrad':
        return torch.optim.Adagrad(params, lr=args.lr, weight_decay=args.weight_decay)
    if args.optim == 'RMSprop':
        return torch.optim.RMSprop(params, lr=args.lr, alpha=0.999, eps=1e-8,
                                   weight_decay=args.weight_decay)
    if args.optim == 'Adam':
        return torch.optim.Adam(params, lr=args.lr, betas=(0.9, 0.999), eps=1e-8,
                                weight_decay=args.weight_decay)
    raise ValueError(
        'unsupported optimizer {!r}; expected one of SGD, SGD_momentum, '
        'Adagrad, RMSprop, Adam'.format(args.optim))


def _build_dataset(args):
    """Return the training dataset named by ``args.dataset``.

    Input file paths are fixed relative paths (resolved against the current
    working directory).
    """
    map_file_path = 'divide.csv'
    data_file_path = 'processed_data.txt'
    social_economical_path = '2010-2016.csv'
    if args.dataset == 'NaiveDataset':
        return NaiveDataset(data_file_path, map_file_path)
    if args.dataset == 'AdvancedDataset':
        return AdvancedDataset(data_file_path, map_file_path, social_economical_path)
    raise ValueError(
        'unsupported dataset {!r}; expected NaiveDataset or '
        'AdvancedDataset'.format(args.dataset))
Example #2
0
# list of uppercase letters to init bandnames
uppers = string.ascii_uppercase

# The first command-line argument selects the recurrent cell.
if len(sys.argv) < 2:
    sys.exit('usage: {} RNN_TYPE  (milstm | lstm | anything else -> LN_miLSTM)'
             .format(sys.argv[0]))
rnn_type = sys.argv[1]

if rnn_type == 'milstm':
    decoder = miLSTM(n_characters, hidden_size, 64, n_characters)
elif rnn_type == 'lstm':
    decoder = LSTM(n_characters, hidden_size, 64, n_characters)
else:
    # NOTE(review): any unrecognized value (including typos) silently
    # selects LN_miLSTM.
    decoder = LN_miLSTM(n_characters, hidden_size, 64, n_characters)

# Move the model to the GPU *before* building its optimizer, as the
# torch.optim documentation recommends.
decoder.cuda()

decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

death_metal_bands = pd.read_csv('../data/death-metal/bands.csv')

band_raw = death_metal_bands['name'].tolist()

# Append an end-of-sequence marker to every band name.
band_nms = [bnd + '<EOS>' for bnd in band_raw]

print('Found', len(band_nms), 'bands!')


def time_since(since):
    s = time.time() - since
    m = math.floor(s / 60)