# Whether to move the model to the GPU, taken from `args.cuda`.
CUDA_AVAILABLE = args.cuda
print('Running on GPU: ', CUDA_AVAILABLE)

# initialize dataset, dataloader and create batched data
BAND = 'combined'
SEGMENT = '2'
NUM_DATA_POINTS = 1000
NUM_BATCHES = 100
METRIC = args.metric
file_name = BAND + '_' + str(SEGMENT) + '_data'
# Python 3 runs read a separate `_3.dill` file; Python 2 reads the plain `.dill` file.
if sys.version_info[0] < 3:
    data_path = 'dat/' + file_name + '.dill'
else:
    data_path = 'dat/' + file_name + '_3.dill'
dataset = PitchContourDataset(data_path)
dataloader = PitchContourDataloader(dataset, NUM_DATA_POINTS, NUM_BATCHES)
# Every other call to create_split_data in this script unpacks 5 values
# (tr, v, vef, te, tef); unpacking 6 names here raised ValueError.
# Keep the 1st split as training data and the 3rd as validation data.
training_data, _, validation_data, _, _ = dataloader.create_split_data(chunk_len=2000, hop=500)

# initialize the model selected by `args.model`
if args.model == 'pitchrnn':
    perf_model = PitchRnn()
elif args.model == 'pitchcrnn':
    perf_model = PitchCRnn()
else:
    # Fail fast with a clear message; previously `perf_model` was left
    # unbound and the script died with a NameError on the next line.
    raise ValueError('Unknown model type: %s' % args.model)
if CUDA_AVAILABLE:
    perf_model.cuda()
print(perf_model)

# mean-squared-error loss used for evaluation below
criterion = nn.MSELoss()

# Load pretrained weights from saved/<filename>.pt.
# NOTE(review): `filename` is not defined in the visible part of this script
# (only `file_name` is); confirm it is set earlier in the file.
if not torch.cuda.is_available():
    # No GPU: remap any GPU-saved tensors onto CPU storage while loading.
    perf_model.load_state_dict(
        torch.load('saved/' + filename + '.pt',
                   map_location=lambda storage, loc: storage))
else:
    perf_model.cuda()
    perf_model.load_state_dict(torch.load('saved/' + filename + '.pt'))

# Rebuild the dataset/dataloader and take the 1000/500 splits for evaluation.
file_name = BAND + '_' + str(SEGMENT) + '_data'
data_path = 'dat/' + file_name + ('.dill' if sys.version_info[0] < 3 else '_3.dill')
dataset = PitchContourDataset(data_path)
dataloader = PitchContourDataloader(dataset, NUM_DATA_POINTS, NUM_BATCHES)
_, _, vef, _, tef = dataloader.create_split_data(1000, 500)

# Evaluate on the `tef` split (full-length test data).
# NOTE(review): MTYPE and CTYPE are not defined in the visible part of this
# script; confirm they are set earlier in the file.
(test_loss, test_r_sq, test_accu, test_accu2,
 pred, target) = eval_utils.eval_model(
    perf_model, criterion, tef, METRIC, MTYPE, CTYPE, 1)
print('[%s %0.5f, %s %0.5f, %s %0.5f %0.5f]' % (
    'Testing Loss: ', test_loss,
    ' R-sq: ', test_r_sq,
    ' Accu:', test_accu, test_accu2))

# Convert predictions and targets to numpy arrays. Tensor.cpu() returns the
# tensor itself when it is already in CPU memory, so one path serves both
# GPU and CPU runs.
pred = pred.clone().cpu().numpy()
target = target.clone().cpu().numpy()

# Resolve dataset locations for the current machine.
# NOTE(review): the condition guarding the first branch was lost from this
# script (an `else:` was left with no matching `if`, which is a SyntaxError).
# It is reconstructed here as "the local MAST checkout exists" -- confirm
# the original intent.
import os
_LOCAL_MAST_PATH = '/Users/Som/GitHub/Mastmelody_dataset/f0data'
if os.path.isdir(_LOCAL_MAST_PATH):
    mast_path = _LOCAL_MAST_PATH
else:
    if torch.cuda.is_available():
        data_path = '/home/data_share/FBA/fall19/data/pitch_contour/' + BAND + '_2_pc_3.dill'
    else:
        data_path = '/Volumes/Farren/python_stuff/dat/' + BAND + '_2_data_3.dill'

    mast_path = '/home/apati/MASTmelody_dataset/f0data'

# Build the dataset/dataloader for the selected band. The MAST data uses its
# own dataset/loader classes and also sets CTYPE to 1.
if BAND != 'mast':
    # NOTE(review): `datasets_all` and `instrument` are not defined in the
    # visible part of this script; confirm they are set earlier.
    dataset = PitchContourDataset(datasets_all[instrument])
    dataloader = PitchContourDataloader(dataset, NUM_DATA_POINTS, NUM_BATCHES)
else:
    dataset = MASTDataset(mast_path)
    dataloader = MASTDataloader(dataset)
    CTYPE = 1


# Chunk the data at three (chunk_len, hop) settings and pool the results.
# Only the first call's `vef` / `tef` splits are kept.
tr1, v1, vef, te1, tef = dataloader.create_split_data(1000, 500)
tr2, v2, _, te2, _ = dataloader.create_split_data(1500, 500)
tr3, v3, _, te3, _ = dataloader.create_split_data(2000, 1000)
# Earlier experiments also pooled (2500, 1000), (3000, 1500) and (4000, 2000).

training_data = tr1 + tr2 + tr3  # this is the proper training data split
testing_data = te1 + te2 + te3
validation_data = vef  # v1..v3 are not pooled into validation

## augment data
aug_training_data = train_utils.augment_data(training_data)