def train(args):
    """Fit a model built with ``from_ckpt=False`` on the data named in ``args``.

    Prints the train/val paths and the loaded array shapes. Lists every
    ``tf.trainable_variables()`` entry with its shape and rank, then prints the
    total trainable-parameter count before calling ``model.fit`` for
    ``args.epochs`` epochs.

    Args:
        args: parsed options; this function reads ``train``, ``val`` and
            ``epochs`` directly and forwards the whole object to ``get_model``.
    """
    train_path, val_path = args.train, args.val
    print(train_path)
    print(val_path)

    X_train, Y_train = load_h5(train_path)
    X_val, Y_val = load_h5(val_path)
    print(X_train.shape, Y_train.shape)
    print(X_val.shape, Y_val.shape)

    # The per-example target shape is unpacked as (n_dim, n_chan).
    target_shape = Y_train[0].shape
    n_dim, n_chan = target_shape
    # NOTE(review): index 1 of the per-example shape is the channel axis
    # (n_chan), so this is a channel ratio rather than a length ratio —
    # confirm that is what get_model expects as `r`.
    r = target_shape[1] / X_train[0].shape[1]
    assert n_chan == 1

    model = get_model(args, n_dim, r, from_ckpt=False, train=True)

    # Report each trainable variable and accumulate the overall parameter count.
    grand_total = 0
    for var in tf.trainable_variables():
        var_shape = var.get_shape()
        print(var.name, var_shape, len(var_shape))
        count = 1
        for d in var_shape:
            count *= d.value
        grand_total += count
    print('total_parameters', grand_total)

    model.fit(X_train, Y_train, X_val, Y_val, n_epoch=args.epochs)
def train(args):
    """Fit a model built with ``from_ckpt=False`` on (X, Y, Z) train/val data.

    In this variant ``load_h5`` yields three arrays per split. All three are
    forwarded to ``model.fit``.

    Args:
        args: parsed options; this function reads ``train``, ``val`` and
            ``epochs`` directly and forwards the whole object to ``get_model``.
    """
    X_train, Y_train, Z_train = load_h5(args.train)
    X_val, Y_val, Z_val = load_h5(args.val)
    # Only the shape is reported. The previous element dumps (Z_train[0],
    # Z_train[3]) were debug leftovers that raised IndexError on training
    # sets with fewer than 4 examples.
    print(Z_train.shape)

    # The per-example target shape is unpacked as (n_dim, n_chan).
    n_dim, n_chan = Y_train[0].shape
    # NOTE(review): index 1 is the channel axis here, so this is a channel
    # ratio rather than a length ratio — confirm intended.
    r = Y_train[0].shape[1] / X_train[0].shape[1]
    assert n_chan == 1

    # To resume instead of starting fresh, build with from_ckpt=True and call
    # model.load(args.logname) before fitting.
    model = get_model(args, n_dim, r, from_ckpt=False, train=True)

    model.fit(X_train, Y_train, Z_train, X_val, Y_val, Z_val, n_epoch=args.epochs)
def train(args):
    """Build a model with ``from_ckpt=False`` and fit it on the train/val data.

    Args:
        args: parsed options; ``train`` and ``val`` are passed to ``load_h5``,
            ``epochs`` sets the number of epochs, and the whole object is
            forwarded to ``get_model``.
    """
    X_train, Y_train = load_h5(args.train)
    X_val, Y_val = load_h5(args.val)

    # Model dimensions come from the first target example, unpacked as (n_dim, n_chan).
    y_shape = Y_train[0].shape
    n_dim, n_chan = y_shape
    # NOTE(review): y_shape[1] is n_chan, so r is a channel ratio — confirm intended.
    r = y_shape[1] / X_train[0].shape[1]
    assert n_chan == 1

    model = get_model(args, n_dim, r, from_ckpt=False, train=True)
    model.fit(X_train, Y_train, X_val, Y_val, n_epoch=args.epochs)
def spline(args):
    """Evaluate ``upsample``-processed validation inputs against their targets.

    The validation split comes from ``load_h5(args.val)`` when
    ``args.grocery == 'false'``. Otherwise it comes from the grocery pickle
    files suffixed with ``args.val``. Each input row is replaced in place by
    ``upsample(row, 1)``. The function then prints:

    * the mean of ``sqrt(mean((P - Y)**2 + 1e-6))`` per example,
    * the mean SNR in dB, ``20 * log10(||Y|| / loss + 1e-8)``,
    * ``compute_log_distortion`` on the flattened targets and predictions.

    Args:
        args: parsed options; reads ``grocery``, ``val`` and ``piano``.
    """
    if args.grocery == 'false':
        X_val, Y_val = load_h5(args.val)
    else:
        # Pickles must be read in binary mode (text mode fails on Python 3),
        # and `with` guarantees the handles are closed.
        # NOTE(review): pickle can execute arbitrary code — only load trusted files.
        with open("../data/grocery/grocery/grocery-test-data_" + args.val, 'rb') as f:
            X_val = pickle.load(f)
        with open("../data/grocery/grocery/grocery-test-label" + args.val, 'rb') as f:
            Y_val = pickle.load(f)

    # NOTE(review): the upsampling factor is 1 here — confirm that is intended.
    for i in range(len(X_val)):
        X_val[i, :] = upsample(X_val[i, :], 1)

    t, P, Y = "val", X_val, Y_val
    # Piano data carries an extra axis, so per-example reductions span two axes.
    axes = (1, 2) if args.piano == 'true' else 1

    sqrt_l2_loss = np.sqrt(np.mean((P - Y) ** 2 + 1e-6, axis=axes))
    sqrn_l2_norm = np.sqrt(np.mean(Y ** 2, axis=axes))
    snr = 20 * np.log(sqrn_l2_norm / sqrt_l2_loss + 1e-8) / np.log(10.)
    avg_snr = np.mean(snr, axis=0)
    lsd = compute_log_distortion(np.reshape(Y, (-1)), np.reshape(P, (-1)))
    avg_sqrt_l2_loss = np.mean(sqrt_l2_loss, axis=0)

    print(t + " l2 loss: " + str(avg_sqrt_l2_loss))
    print(t + " average SNR: " + str(avg_snr))
    print(t + " lsd: " + str(lsd))
def train(args):
    """Train either a seq2seq model or the ``get_model`` model on the selected data.

    Data comes from ``load_h5`` (``args.grocery == 'false'``) or from the
    grocery pickle files, which get a trailing singleton axis appended. When
    ``args.piano == 'true'`` the last two extents of every array are swapped
    by reshape.

    * ``args.model == 'seq2seq'``: trains ``models.Model2`` via ``run``.
    * otherwise: builds ``get_model(..., from_ckpt=False, train=True)`` and
      calls ``fit``, with ``calc_full_snr`` set from ``args.full == 'true'``.

    Args:
        args: parsed options; reads ``full``, ``grocery``, ``train``, ``val``,
            ``piano``, ``model``, ``epochs``, ``r`` and ``speaker``.
    """
    full = args.full == 'true'

    def _load_pickle(path):
        # Binary mode is required for pickle; `with` closes the handle.
        # NOTE(review): pickle can execute arbitrary code — only load trusted files.
        with open(path, 'rb') as f:
            return pickle.load(f)

    def _add_channel_axis(a):
        # (N, T) -> (N, T, 1)
        return np.reshape(a, [a.shape[0], a.shape[1], 1])

    def _swap_last_two_extents(a):
        # Uses each array's OWN dims (train arrays previously borrowed val dims).
        # NOTE(review): reshape reinterprets memory rather than transposing;
        # use np.swapaxes if a real axis swap is intended.
        return np.reshape(a, [a.shape[0], a.shape[2], a.shape[1]])

    if args.grocery == 'false':
        X_train, Y_train = load_h5(args.train)
        X_val, Y_val = load_h5(args.val)
    else:
        X_train = _load_pickle("../data/grocery/grocery/grocery-train-data" + args.train)
        Y_train = _load_pickle("../data/grocery/grocery/grocery-train-label" + args.train)
        # NOTE(review): validation files are keyed on args.train, whereas
        # spline() keys the same files on args.val — confirm which is intended.
        X_val = _load_pickle("../data/grocery/grocery/grocery-test-data_" + args.train)
        Y_val = _load_pickle("../data/grocery/grocery/grocery-test-label" + args.train)
        X_train = _add_channel_axis(X_train)
        Y_train = _add_channel_axis(Y_train)
        X_val = _add_channel_axis(X_val)
        Y_val = _add_channel_axis(Y_val)

    if args.piano == 'true':
        X_train = _swap_last_two_extents(X_train)
        Y_train = _swap_last_two_extents(Y_train)
        X_val = _swap_last_two_extents(X_val)
        Y_val = _swap_last_two_extents(Y_val)

    print(Y_train[0].shape)
    print(Y_train[0])
    # The per-example target shape is unpacked as (n_dim, n_chan).
    n_dim, n_chan = Y_train[0].shape
    # NOTE(review): index 1 is the channel axis, so this is a channel ratio — confirm.
    r = Y_train[0].shape[1] / X_train[0].shape[1]
    assert n_chan == 1

    if args.model == 'seq2seq':
        model = models.Model2()
        model.run(X_train, Y_train, X_val, Y_val, n_epoch=args.epochs,
                  r=args.r, speaker=args.speaker, grocery=args.grocery)
    else:
        model = get_model(args, n_dim, r, from_ckpt=False, train=True)
        model.fit(X_train, Y_train, X_val, Y_val, n_epoch=args.epochs,
                  r=args.r, speaker=args.speaker, grocery=args.grocery,
                  piano=args.piano, calc_full_snr=full)
def train(args):
    """Upsample the inputs by the target/input length ratio and fit a model.

    Inputs are passed through ``upsample_training_data`` with factor
    ``int(r)``, where ``r = Y_train.shape[1] / X_train.shape[1]``. When
    ``args.from_ckpt`` is set, the model is built with ``from_ckpt=True`` and
    restored from that path before training.

    Args:
        args: parsed options; reads ``train``, ``val``, ``from_ckpt`` and
            ``epochs``, and forwards the whole object to ``get_model``.
    """
    X_train, Y_train = load_h5(args.train)
    X_val, Y_val = load_h5(args.val)

    # int(r) below truncates, so require an exact integer ratio between target
    # and input lengths. The previous check compared X_val to X_train, which
    # did not validate r at all.
    assert Y_train.shape[1] % X_train.shape[1] == 0
    # The per-example target shape is unpacked as (n_dim, n_chan).
    n_dim, n_chan = Y_train[0].shape
    r = Y_train.shape[1] / X_train.shape[1]
    factor = int(r)
    X_train = upsample_training_data(X_train, factor)
    X_val = upsample_training_data(X_val, factor)
    assert n_chan == 1

    if args.from_ckpt is None:
        model = get_model(args, n_dim, r, from_ckpt=False, train=True)
    else:
        model = get_model(args, n_dim, r, from_ckpt=True, train=True)
        model.load(args.from_ckpt)  # restore from the user-supplied checkpoint path

    model.fit(X_train, Y_train, X_val, Y_val, n_epoch=args.epochs)