def main():
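	"""Entry point: parse command-line arguments, then either train MahlerNet or generate output from a saved run.

	Example invocations (a sketch only: the script name and exact flag spellings are assumptions inferred from how params is used below):

		python run.py --mode train --name my_run --config config.txt
		python run.py --mode generate --name my_run --model model.ckpt --type pred --units - --length 8
	"""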
	params = parser.parse_args()
	params.start_ctx = "START" if params.use_start_ctx else None
	params.type = [] if params.type is None else params.type
	params.with_ctx = not params.no_ctx	
	mn_params = {}
	orig_params_file = None
	save_dir = os.path.join(params.root, "runs", params.name)
	assert(len(params.name) > 0), "[RUN]: ERROR - please supply a non-empty, unique run id"
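	# train mode: create a new run folder (or reuse an existing one when continuing) and load hyperparameters from a config file or from the previous run's params.txt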
	if params.mode == "train":
		train_mission = True
		gen_mission = False
		params_file = None
		if params.cont:
			assert(os.path.exists(save_dir)), "[RUN]: ERROR - can't continue training from non-existing run folder '" + str(save_dir) + "'"
			assert(params.model is not None), "[RUN]: ERROR - must supply a saved model to continue training"
			if os.path.exists(os.path.join(save_dir, "params.txt")):
				params_file = os.path.join(save_dir, "params.txt")
		else:
			assert(not os.path.exists(save_dir)), "[RUN]: ERROR - the name specified for this run already exists; please specify a new unique name or delete the folder named '" + str(params.name) + "' in subdir 'runs'"
			params_file = params.config
		if params_file is not None:
			with open(params_file, mode = "r") as pf:
				if params.cont:
					mn_params = json.loads(pf.read())
				else:
					mn_params = ast.literal_eval(pf.read())
				orig_params_file = copy.deepcopy(mn_params)
		os.makedirs(os.path.join(params.root, "runs"), exist_ok = True)		
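	# generate mode: require an existing run folder and reload the parameters the run was trained with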
	elif params.mode == "generate":
		train_mission = False
		gen_mission = True
		assert(os.path.exists(save_dir)), "[RUN]: ERROR - specified path '" + str(save_dir) + "' does not name a run folder in '" + str(os.path.join(params.root, "runs")) + "'"
		params_path = os.path.join(save_dir, "params.txt")
		if os.path.exists(params_path):
			with open(params_path, mode = "r") as pf:
				mn_params = json.loads(pf.read())

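	# the DataProcessor supplies the tokenized training data and the vocabulary layout used below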
	dp = DataProcessor()
	
	if train_mission:
		if len(mn_params) > 0:
			dp.setup_training(root_dir = params.root, validation_set_ratio = mn_params["validation_set_ratio"], random_ordering = mn_params["use_randomness"])
			assert(mn_params["generation"]["sequence_base"] == dp.sequence_base), "[RUN]: different sequence base in prepared training directory and in input parameters: \"" + str(mn_params["generation"]["sequence_base"]) + "\" and \"" + str(dp.sequence_base) + "\""
		else:
			dp.setup_training(root_dir = params.root)
			mn_params["generation"]["sequence_base"] = dp.sequence_base
			
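	# parameters derived from the data processor: feature vocabulary sizes plus the (start, end) slice each feature occupies in the model's input vector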
	new_params = {
		"num_durations" : dp.mp.num_durations,
		"num_pitches" : dp.mp.NUM_PITCHES,
		"num_instruments" : dp.mp.NUM_INSTRUMENTS,
		"num_beats" : dp.mp.NUM_BEATS,
		"num_special_tokens": dp.mp.NUM_SPECIAL_TOKENS,
		"modelling_properties": {
			"offset": {
				"indices": (dp.mp.time_a, dp.mp.time_a + dp.mp.num_durations)
			},
			"beats": {
				"indices": (dp.mp.beats_a, dp.mp.beats_a + dp.mp.NUM_BEATS)
			},
			"duration": {
				"indices": (dp.mp.dur_a, dp.mp.dur_a + dp.mp.num_durations)
			},
			"pitch": {
				"indices": (dp.mp.inp_a, dp.mp.inp_a + dp.mp.NUM_PITCHES)
			},
			"instrument": {
				"indices": (dp.mp.instr_a, dp.mp.instr_a + dp.mp.NUM_INSTRUMENTS)
			},
			"active_pitches": {
				"indices": (dp.mp.act_notes_a, dp.mp.act_notes_a + dp.mp.NUM_PITCHES)
			},
			"active_instruments": {
				"indices": (dp.mp.act_inst_a, dp.mp.act_inst_a + dp.mp.NUM_INSTRUMENTS)
			},
			"special_tokens": {
				"indices": (dp.mp.final_tokens_a, dp.mp.final_tokens_a + dp.mp.NUM_SPECIAL_TOKENS)
			}
		},
		"generation": {
			"default_durations": np.reshape(dp.mp.default_durations, [1, 1, dp.mp.default_durations.shape[0]]),
			"default_bar": dp.mp.default_bar_unit_duration,
			"default_beats": np.reshape(dp.mp.default_beats_vector, [1, 1, dp.mp.default_beats_vector.shape[0]]),
			"default_duration_set": dp.mp.default_duration_set,
			"default_duration_sets": [0, 0, dp.mp.default_duple_duration_set, dp.mp.default_triple_duration_set, dp.mp.default_common_duration_set],
			"default_duplet_duration": "d8",
			"default_triplet_duration": "t8"
		}
	}
	if train_mission:
		new_params["generation"]["ctx_length"] = dp.ctx_length
		new_params["generation"]["inp_length"] = dp.inp_length
	def update_dict(orig, add):
		for k in add:
			if isinstance(add[k], dict):
				if k not in orig:
					orig[k] = add[k]
				else:
					assert(isinstance(orig[k], dict)), "[RUN]: ERROR - while merging parameters for MahlerNet, dict key \"" + k + "\" should be a dict but is not in config file"
					orig[k] = update_dict(orig[k], add[k])
			else:
				orig[k] = add[k]
		return orig
	mn_params = update_dict(mn_params, new_params)
	mn_params["root_dir"] = params.root
	mn_params["model_name"] = params.name
	mn_params["save_dir"] = save_dir
	mn = MahlerNet(mn_params)
	mn.build()
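	# generation units requested on the command line: input holds the parsed units, empty flags the all-empty '-' unit, and ctxs/zs count how many units carry a context or a latent input, respectively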
	input = None
	empty = False
	ctxs = 0
	zs = 0
	if params.units is not None:
		empty, input, ctxs, zs = parse_units(params.units)
		input.sort()
		if params.file is not None:
			generator = dp.data_generator(params.file + ".pickle", mn.params["root_dir"], "data", mn.params["generation"]["sequence_base"], mn.params["generation"]["ctx_length"], mn.params["generation"]["inp_length"])
			input = fetch_units(input, generator)
		else:
			assert(len(input) == 0), "[RUN]: ERROR - units are given for generation but no file was supplied; only '-' is allowed as a position without an input file"
		if empty:
			input = [("empty", None, None)] + input			
	elif params.file is not None:
		generator = dp.data_generator(params.file + ".pickle", mn.params["root_dir"], "data", mn.params["generation"]["sequence_base"], mn.params["generation"]["ctx_length"], mn.params["generation"]["inp_length"])
	
	# check conditions applying to both training and generating before proceeding
	assert("recon" not in params.type or len(input) > 0 and params.file is not None), "[RUN]: ERROR - must supply file and > 0 input units to generate by reconstruction with option 'recon'"
	assert("pred" not in params.type or empty or (input is not None and len(input) > 0 and params.file is not None)), "[RUN]: ERROR - must supply file and > 0 input units or '-' unit to generate by prediction with option 'pred'"
	
	# if units were given, input is now a list of tuples where each tuple has the form (n, c, i): n names the unit, while c and i are None if that property should use the default (if modelled) and otherwise (bool, prop)
	# if prop should be used (if modelled). The bool indicates whether the input must first be processed with latent() or context(), or whether it is already the processed version. empty may be set to True, indicating
	# the use of the all-empty input (None, None), in which case a sampled z and the initial ctx are used.
	if train_mission:
		assert("intpol" not in params.type and "n_pred" not in params.type and "n_recon" not in params.type), "[RUN]: ERROR - only 'recon' and 'pred' are available for continuous generation while training"
		print("[RUN]: training with", dp.total, "samples with a maximum of", dp.max_context_length, "steps in context and", dp.max_input_length, "in input")
		if params.model is not None and params.cont:
			mn.load_model(params.model)
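		# for a given batch size, yield a (training, validation) pair of batch generators; the boolean argument to batch_generator appears to select the validation split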
		generator_fn = lambda batch: (batch_generator(dp, batch, False, params.root, mn_params["use_randomness"], params.max_limit), batch_generator(dp, batch, True, params.root, mn_params["use_randomness"], params.max_limit))
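		# end-of-epoch callbacks: optionally run reconstruction and/or prediction tests after every epoch, tagging the output with the current epoch and step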
		eoe_fns = []
		if "recon" in params.type:
			eoe_fns += [lambda model, epoch, step: reconstruction_test(model, dp, "epoch" + str(epoch) + "_step" + str(step), input, params.samples, params.meter, params.use_triplets, params.start_ctx, params.use_teacher_forcing)]
		if "pred" in params.type:
			eoe_fns += [lambda model, epoch, step: prediction_test(model, dp, "epoch" + str(epoch) + "_step" + str(step), input, params.samples, params.meter, params.length, params.use_triplets, params.start_ctx)]
		os.makedirs(mn.params["save_dir"], exist_ok = True)
		with open(os.path.join(save_dir, "commands.txt"), "a") as f: # mode "a" creates the file if it does not exist; the newline keeps successive commands on separate lines
			f.write(" ".join(sys.argv) + "\n")
		with open(os.path.join(mn.params["save_dir"], 'all_params.txt'), 'w') as all_params_file: # save ALL params of the current setup before starting
			all_params_file.write(json.dumps(dict_values_to_str(mn.params), sort_keys = False, indent = 4))
		if orig_params_file is not None:
			with open(os.path.join(mn.params["save_dir"], 'params.txt'), 'w') as params_file: # save input params for reference and future runs of the current setup before starting
				params_file.write(json.dumps(orig_params_file, sort_keys=False, indent = 4))
		(epoch_losses, step_losses, epoch_prec_rec, step_prec_rec, epoch_dist, step_dist) = mn.train(generator_fn, dp.total, dp.sz_training_set, dp.sz_validation_set, init_vars = params.model is None, eoe_fns = eoe_fns)	
		if mn.params["save_stats"]:
			os.makedirs(os.path.join(mn.params["save_dir"], "records"))	
			path = os.path.join(mn.params["save_dir"], "records", "stats")
			file = open(path, mode = "xb")
			pickle.dump((epoch_losses, step_losses, epoch_prec_rec, step_prec_rec, epoch_dist, step_dist), file)
			print("[RUN]: saved training statistics to", path)
			file.close()		
	else:
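		# generate mode: load a trained model and run each requested generation type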
		assert(params.model is not None), "[RUN]: ERROR - must specify a model name within the 'saved_models' folder of the root directory to generate something"
		assert(len(params.type) > 0), "[RUN]: ERROR - must specify at least one type of generation to generate from a loaded model"
		with open(os.path.join(save_dir, "commands.txt"), "a") as f:
			f.write(" ".join(sys.argv))			
		mn.load_model(params.model)
		if "recon" in params.type:
			assert(mn.params["model"]["vae"]), "[RUN]: ERROR - must have a model that includes a vae to run reconstruction tests"
			reconstruction_test(mn, dp, mn.params["model_name"], input, params.samples, params.meter, params.use_triplets, params.start_ctx, params.use_teacher_forcing)
		if "n_recon" in params.type:
			assert(mn.params["model"]["vae"]), "[RUN]: ERROR - must have a model that includes a vae to run reconstruction tests"
			if params.file is not None: # use the input file for the generator
				r_gen = dp.data_generator(params.file + ".pickle", mn.params["root_dir"], "data", mn.params["generation"]["sequence_base"], mn.params["generation"]["ctx_length"], mn.params["generation"]["inp_length"])
			else: # use the training generator, must adapt its output however
				_, _, f = dp.setup_dirs(mn.params["root_dir"], "data", "pickle", files = None)
				def generator_converter(filenames):
					for filename in filenames:
						yield from dp.data_generator(filename, mn.params["root_dir"], "data", mn.params["generation"]["sequence_base"], mn.params["generation"]["ctx_length"], mn.params["generation"]["inp_length"])
				r_gen = generator_converter(f)
			n_reconstruction_test(mn, dp, mn.params["model_name"], r_gen, params.with_ctx, params.start_ctx, params.use_teacher_forcing)
		if "pred" in params.type:
			prediction_test(mn, dp, params.file, input, params.samples, params.meter, params.length, params.use_triplets, params.start_ctx)
		if "n_pred" in params.type:
			assert(zs <= 1 and ctxs > 1), "[RUN]: ERROR - specify several context units, each followed by 'c', and at most one input unit to use the n_pred generation type"
			assert(mn.params["model"]["ctx"] and mn.params["model"]["vae"]), "[RUN]: ERROR - must use a model that uses both a vae and a context to use the \"n_pred\" generation type"
			z_units = [(n, i) for (n, c, i) in input if i is not None]
			ctx_units = [(n, c) for (n, c, i) in input if c is not None]
			n_prediction_test(mn, dp, params.file, ctx_units, z_units[0] if len(z_units) > 0 else None, params.meter, params.length, params.use_triplets)
		if "intpol" in params.type:
			assert(zs == 2 and ctxs <= 1), "[RUN]: ERROR - must supply exactly two input units ('-' not counted) and an optional context to use for interpolating between latent states"
			z_units = [(n, i) for (n, c, i) in input if i is not None]
			ctx_units = [(n, c) for (n, c, i) in input if c is not None]
			interpolation_test(mn, dp, params.file, ctx_units[0] if len(ctx_units) > 0 else params.start_ctx, z_units, params.use_slerp, params.steps, params.meter, params.use_triplets, params.start_ctx)