def run_thread(gpu, iters_to_calc, all_weights, shapes, train_y_shape, train_generator, val_generator, dim_sum, args, dsets, hf_grads):
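    """Worker for a single GPU: rebuilds the model on that device, then walks
    its assigned checkpoint iterations, computing train/test gradients at
    fractional steps between consecutive weight snapshots. The number of
    sub-steps per interval is doubled (up to 32) until the linear
    approximation of the loss change is within args.error_threshold, and
    results are written to per-GPU datasets in the shared HDF5 file."""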
    # each process writes to a different variable in the file
    grads_train_key = 'grads_train_{}'.format(gpu)
    grads_test_key = 'grads_test_{}'.format(gpu)

    # build model for this process/device
    with tf.device('/device:GPU:{}'.format(gpu)):
        if args.arch == 'linknet':
            model = network_builders.build_linknet()
        elif args.arch == 'fc':
            model = network_builders.build_network_fc(args)
        elif args.arch == 'fc_cust':
            model = network_builders.build_fc_adjustable(args)
        elif args.arch == 'lenet':
            model = network_builders.build_lenet_conv(args)
        elif args.arch == 'allcnn':
            model = network_builders.build_all_cnn(args)
        elif args.arch == 'resnet':
            model = network_builders.build_resnet(args)
        elif args.arch == 'vgg':
            model = network_builders.build_vgg_half(args)
        init_model(model, args)
        define_training(model, args)

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    sess.run(tf.global_variables_initializer())

    grads_train = np.zeros((args.default_num_splits + 1, dim_sum)) # all the ones needed for current iteration rk4
    grads_test = np.zeros((args.default_num_splits + 1, dim_sum))
    newsize = len(iters_to_calc) * args.default_num_splits + 1 # used to resize
    timerstart = time.time()

    # get 0th iteration
    cur_weights_flat = all_weights[iters_to_calc[0]]
    (grads_train[0], cur_train_loss, grads_test[0], cur_test_loss) = load_and_calculate(
        sess, model, split_and_shape(cur_weights_flat, shapes), train_y_shape, train_generator, val_generator, dim_sum, args)
    dsets['trainloss'][iters_to_calc[0]] = cur_train_loss
    dsets['testloss'][iters_to_calc[0]] = cur_test_loss
    dsets[grads_train_key][0] = grads_train[0]
    dsets[grads_test_key][0] = grads_test[0]
    grads_ind = 1 # current index to write to for gradients arrays

    for iterations in iters_to_calc:
        # get next iteration gradients and loss, so we have a ground truth loss
        next_weights_flat = all_weights[iterations + 1]
        (grads_train[-1], next_train_loss, grads_test[-1], next_test_loss) = load_and_calculate(
            sess, model, split_and_shape(next_weights_flat, shapes), train_y_shape, train_generator, val_generator, dim_sum, args)

        # get the middle fractional iterations
        get_fractional_gradients(range(1, args.default_num_splits), args.default_num_splits, cur_weights_flat,
            next_weights_flat, grads_train, grads_test, train_y_shape, train_generator, val_generator, shapes, sess, model, dim_sum, args) #1, or 1,2,3

        # tuple of (train loss diff, test loss diff)
        approx_errors = (calc_approx_error(cur_train_loss, next_train_loss, grads_train, next_weights_flat - cur_weights_flat),
            calc_approx_error(cur_test_loss, next_test_loss, grads_test, next_weights_flat - cur_weights_flat))
        num_splits = args.default_num_splits

        # do smaller splits until error is small enough
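        # Each refinement doubles num_splits, reuses the gradients already
        # computed (placed at the even indices) and evaluates only the new
        # odd-indexed midpoints between the current and next checkpoint.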
        while np.abs(approx_errors).max() > args.error_threshold and num_splits < 32:
            newsize += num_splits # need to resize by this much
            num_splits *= 2
            grads_train_halved = np.zeros((num_splits + 1, dim_sum))
            grads_test_halved = np.zeros((num_splits + 1, dim_sum))
            grads_train_halved[0:num_splits + 1:2] = grads_train # every odd index is zeros
            grads_test_halved[0:num_splits + 1:2] = grads_test

            # get quarter gradients, fill in the rest of grads_train_halved
            get_fractional_gradients(range(1, num_splits, 2), num_splits, cur_weights_flat, next_weights_flat,
                grads_train_halved , grads_test_halved, train_y_shape, train_generator, val_generator, shapes, sess, model, dim_sum, args) # 1,3 or 1,3,5,7
            grads_train = grads_train_halved
            grads_test = grads_test_halved
            approx_errors = (calc_approx_error(cur_train_loss, next_train_loss, grads_train, next_weights_flat - cur_weights_flat),
                calc_approx_error(cur_test_loss, next_test_loss, grads_test, next_weights_flat - cur_weights_flat))

        # actually writing to file
        dsets['trainloss'][iterations + 1] = next_train_loss
        dsets['testloss'][iterations + 1] = next_test_loss
        dsets['num_splits'][iterations] = num_splits
        if grads_ind + num_splits > dsets[grads_train_key].shape[0]: # resize when you have to
            dsets[grads_train_key].resize((newsize, dsets[grads_train_key].shape[1]))
            dsets[grads_test_key].resize((newsize, dsets[grads_test_key].shape[1]))
        dsets[grads_train_key][grads_ind:grads_ind + num_splits] = grads_train[1:] # 0 written in previous iteration
        dsets[grads_test_key][grads_ind:grads_ind + num_splits] = grads_test[1:]
        grads_ind += num_splits

        # set variables for next iteration
        cur_weights_flat = next_weights_flat
        cur_train_loss, cur_test_loss = next_train_loss, next_test_loss
        grads_train_new = np.zeros((args.default_num_splits + 1, dim_sum))
        grads_test_new = np.zeros((args.default_num_splits + 1, dim_sum))
        grads_train_new[0], grads_test_new[0] = grads_train[-1], grads_test[-1]
        grads_train, grads_test = grads_train_new, grads_test_new

        if (iterations - iters_to_calc[0]) % args.print_every == 0:
            print('iter {} from gpu {} ({:.2f} s)'.format(iterations, gpu, time.time() - timerstart))

    return gpu
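
# The helpers called above (load_and_calculate, get_fractional_gradients,
# calc_approx_error, split_and_shape, divide_with_remainder) are defined
# elsewhere in the project and are not shown on this page. As a minimal
# sketch, split_and_shape is assumed to simply cut the flat weight vector
# back into per-variable arrays matching the shapes stored in the HDF5
# metadata (numpy imported as np at module level):
def split_and_shape(flat_weights, shapes):
    arrays, offset = [], 0
    for shape in shapes:
        size = int(np.prod(shape)) if len(shape) else 1
        arrays.append(np.reshape(flat_weights[offset:offset + size], shape))
        offset += size
    return arrays
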
Example #2
def main():
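    """Train the selected architecture and optionally record the flattened
    weight vector (and training gradients) to an HDF5 file under
    args.output_dir. This variant loads train/test arrays from H5 files and
    additionally builds SpaceNet road generators, though train_and_eval is
    invoked on the H5 arrays."""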
    parser = make_parser()
    args = parser.parse_args()

    if args.tf_seed != -1:
        tf.random.set_random_seed(args.tf_seed)
    if not args.no_shuffle and args.shuffle_seed != -1:
        np.random.seed(args.shuffle_seed)

    # load data
    train_x, train_y = read_input_data(args.train_h5)
    test_x, test_y = read_input_data(args.test_h5) # used as val

    # SpaceNet
    all_ids = np.array(generate_ids(args.data_dirs, None))
    kfold = KFold(n_splits=2, shuffle=True)  # args.n_folds
    splits = [s for s in kfold.split(all_ids)]
    folds = [int(f) for f in '0'.split(",")]  # hard-coded to fold 0
    fold = folds[0]
    train_ind, test_ind = splits[fold]
    train_ids = all_ids[train_ind]
    val_ids = all_ids[test_ind]
    masks_dict = get_groundtruth(args.data_dirs)

    # Returns normalized to interval [-1, 1]
    train_generator = MULSpacenetDataset(
        data_dirs=args.data_dirs,
        wdata_dir=args.wdata_dir,
        image_ids=train_ids,
        batch_size=args.train_batch_size,
        crop_shape=(args.crop_size, args.crop_size),
        seed=777,
        image_name_template='PS-MS/SN3_roads_train_AOI_5_Khartoum_PS-MS_{id}.tif',
        masks_dict=masks_dict
    )

    val_generator = MULSpacenetDataset(
        data_dirs=args.data_dirs,
        wdata_dir=args.wdata_dir,
        image_ids=val_ids,
        batch_size=args.test_batch_size,
        crop_shape=(args.crop_size, args.crop_size),
        seed=777,
        image_name_template='PS-MS/SN3_roads_train_AOI_5_Khartoum_PS-MS_{id}.tif',
        masks_dict=masks_dict
    )

    # train_x in shape (batch_size, width, height, channels) = (train_batch_size, crop_size, crop_size, 12)
    # train_x, train_y = next(train_generator)
    # train_generator.reset()
    # test_x, test_y = next(val_generator)
    # val_generator.reset()
    # train_y_shape = train_y.shape


    images_scale = np.max(train_x)
    if images_scale > 1:
        print('Normalizing images by a factor of {}'.format(images_scale))
        train_x = train_x / images_scale
        test_x = test_x / images_scale

    if args.test_batch_size == 0:
        args.test_batch_size = test_y.shape[0]

    print('Data shapes:', train_x.shape, train_y.shape, test_x.shape, test_y.shape)

    if train_y.shape[0] % args.train_batch_size != 0:
        print("WARNING batch size doesn't divide train set evenly")
    if train_y.shape[0] % args.large_batch_size != 0:
        print("WARNING large batch size doesn't divide train set evenly")
    if test_y.shape[0] % args.test_batch_size != 0:
        print("WARNING batch size doesn't divide test set evenly")

    # build model
    if args.arch == 'linknet':
        model = network_builders.build_linknet()
    elif args.arch == 'fc':
        model = network_builders.build_network_fc(args)
    elif args.arch == 'fc_cust':
        model = network_builders.build_fc_adjustable(args)
    elif args.arch == 'lenet':
        model = network_builders.build_lenet_conv(args)
    elif args.arch == 'allcnn':
        model = network_builders.build_all_cnn(args)
    elif args.arch == 'resnet':
        model = network_builders.build_resnet(args)
    elif args.arch == 'vgg':
        model = network_builders.build_vgg_half(args)
    else:
        raise Error("Unknown architeciture {}".format(args.arch))

    init_model(model, args)
    define_training(model, args)

    sess = tf.InteractiveSession(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
    sess.run(tf.global_variables_initializer())

    if args.init_weights_h5:
        load_initial_weights(sess, model, args)

    for collection in ['tb_train_step']: # 'eval_train' and 'eval_test' added manually later
        tf.summary.scalar(collection + '_acc', model.accuracy, collections=[collection])
        tf.summary.scalar(collection + '_loss', model.loss, collections=[collection])

    tb_writer, hf = None, None
    dsets = {}
    if args.output_dir:
        tb_writer = tf.summary.FileWriter(args.output_dir, sess.graph)
        # set up output for gradients/weights
        if args.save_weights:
            dim_sum = sum([tf.size(var).eval() for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)])
            total_iters = args.num_epochs * int(train_y.shape[0] / args.train_batch_size)
            total_chunks = int(total_iters / args.save_every)
            hf = h5py.File(args.output_dir + '/weights', 'w-')

            # write metadata
            var_shapes = np.string_(';'.join([str(var.get_shape()) for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)]))
            hf.attrs['var_shapes'] = var_shapes
            var_names = np.string_(';'.join([str(var.name) for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)]))
            hf.attrs['var_names'] = var_names

            # all individual weights at every iteration, where all_weights[i] = weights before iteration i:
            dsets['all_weights'] = hf.create_dataset('all_weights', (total_chunks + 1, dim_sum), dtype='f8', compression='gzip')
            print(f'all_weights shape: ({total_chunks + 1}, {dim_sum})')
        if args.save_training_grads:
            dsets['training_grads'] = hf.create_dataset('training_grads', (total_chunks, dim_sum), dtype='f8', compression='gzip')

    ########## Run main thing ##########
    print('=' * 100)
    train_and_eval(sess, model, train_x, train_y, test_x, test_y, tb_writer, dsets, args)
    # train_and_eval(sess, model, train_y_shape, train_generator, val_generator, tb_writer, dsets, args)

    if tb_writer:
        tb_writer.close()
    if hf:
        hf.close()
Example #3
def main():
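    """Same training and weight-recording pipeline as the previous example,
    but without the SpaceNet generators: data comes solely from the
    train/test H5 files."""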
    parser = make_parser()
    args = parser.parse_args()

    if args.tf_seed != -1:
        tf.random.set_random_seed(args.tf_seed)
    if not args.no_shuffle and args.shuffle_seed != -1:
        np.random.seed(args.shuffle_seed)

    # load data
    train_x, train_y = read_input_data(args.train_h5)
    test_x, test_y = read_input_data(args.test_h5)  # used as val

    images_scale = np.max(train_x)
    if images_scale > 1:
        print('Normalizing images by a factor of {}'.format(images_scale))
        train_x = train_x / images_scale
        test_x = test_x / images_scale

    if args.test_batch_size == 0:
        args.test_batch_size = test_y.shape[0]

    print('Data shapes:', train_x.shape, train_y.shape, test_x.shape,
          test_y.shape)
    if train_y.shape[0] % args.train_batch_size != 0:
        print("WARNING batch size doesn't divide train set evenly")
    if train_y.shape[0] % args.large_batch_size != 0:
        print("WARNING large batch size doesn't divide train set evenly")
    if test_y.shape[0] % args.test_batch_size != 0:
        print("WARNING batch size doesn't divide test set evenly")

    # build model
    if args.arch == 'fc':
        model = network_builders.build_network_fc(args)
    elif args.arch == 'fc_cust':
        model = network_builders.build_fc_adjustable(args)
    elif args.arch == 'lenet':
        model = network_builders.build_lenet_conv(args)
    elif args.arch == 'allcnn':
        model = network_builders.build_all_cnn(args)
    elif args.arch == 'resnet':
        model = network_builders.build_resnet(args)
    elif args.arch == 'vgg':
        model = network_builders.build_vgg_half(args)
    else:
        raise Error("Unknown architeciture {}".format(args.arch))

    init_model(model, args)
    define_training(model, args)

    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())

    if args.init_weights_h5:
        load_initial_weights(sess, model, args)

    for collection in ['tb_train_step']:  # 'eval_train' and 'eval_test' added manually later
        tf.summary.scalar(collection + '_acc', model.accuracy, collections=[collection])
        tf.summary.scalar(collection + '_loss', model.loss, collections=[collection])

    tb_writer, hf = None, None
    dsets = {}
    if args.output_dir:
        tb_writer = tf.summary.FileWriter(args.output_dir, sess.graph)
        # set up output for gradients/weights
        if args.save_weights:
            dim_sum = sum([
                tf.size(var).eval()
                for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            ])
            total_iters = args.num_epochs * int(
                train_y.shape[0] / args.train_batch_size)
            total_chunks = int(total_iters / args.save_every)
            hf = h5py.File(args.output_dir + '/weights', 'w-')

            # write metadata
            var_shapes = np.string_(';'.join([
                str(var.get_shape())
                for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            ]))
            hf.attrs['var_shapes'] = var_shapes
            var_names = np.string_(';'.join([
                str(var.name)
                for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            ]))
            hf.attrs['var_names'] = var_names

            # all individual weights at every iteration, where all_weights[i] = weights before iteration i:
            dsets['all_weights'] = hf.create_dataset(
                'all_weights', (total_chunks + 1, dim_sum),
                dtype='f8',
                compression='gzip')
        if args.save_training_grads:
            dsets['training_grads'] = hf.create_dataset(
                'training_grads', (total_chunks, dim_sum),
                dtype='f8',
                compression='gzip')

    ########## Run main thing ##########
    print('=' * 100)
    train_and_eval(sess, model, train_x, train_y, test_x, test_y, tb_writer,
                   dsets, args)

    if tb_writer:
        tb_writer.close()
    if hf:
        hf.close()
Example #4
def main():
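    """Driver for the adaptive gradient computation: loads the saved weight
    trajectory from args.weights_h5, creates the 'gradients_adaptive' output
    file with per-GPU datasets, and fans the work out to run_thread on each
    GPU through a thread pool (or runs it directly when only one GPU is
    requested)."""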
    parser = make_parser()
    args = parser.parse_args()

    #Define params for model:
    SEED = args.tf_seed
    BATCH_SIZE = args.train_batch_size
    CHANNEL_SIZE = args.input_dim[2]
    NUM_EPOCHS = args.num_epochs
    IMG_DIM = args.input_dim[0]
    TRAIN_DIR = 'train/'
    TEST_DIR = 'test/'
    CLASSES = {
        0: "No DR",
        1: "Mild",
        2: "Moderate",
        3: "Severe",
        4: "Proliferative DR"
    }

    df_train = pd.read_csv(os.path.join(args.data_dir, "train.csv"))
    df_test = pd.read_csv(os.path.join(args.data_dir, "test.csv"))

    print("Training set has {} samples".format(df_train.shape[0]))
    print("Testing set has {} samples".format(df_test.shape[0]))

    #Process image directories into exact file name (include .png):
    def append_ext(fn):
        return fn + ".png"

    df_train["id_code"] = df_train["id_code"].apply(append_ext)

    # load data into generator:
    # For some reason the generator wants diagnostic labels in string form:
    df_train['diagnosis'] = df_train['diagnosis'].astype(str)

    _validation_split = 0.20

    #x_train_shape = (int(np.round(df_train.shape[0] * (1 - _validation_split))), IMG_DIM, IMG_DIM, CHANNEL_SIZE)
    #x_test_shape = (int(np.round(df_train.shape[0] * _validation_split)), IMG_DIM, IMG_DIM, CHANNEL_SIZE)
    y_train_shape = (int(np.round(df_train.shape[0] *
                                  (1 - _validation_split))), None)
    y_test_shape = (int(np.round(df_train.shape[0] * _validation_split)), None)

    train_datagen = ImageDataGenerator(rescale=1. / 255,
                                       validation_split=_validation_split)

    train_generator = train_datagen.flow_from_dataframe(
        dataframe=df_train,
        directory=args.data_dir + TRAIN_DIR,
        x_col="id_code",
        y_col="diagnosis",
        batch_size=BATCH_SIZE,
        class_mode="categorical",
        target_size=(IMG_DIM, IMG_DIM),
        subset='training',
        seed=SEED)

    val_generator = train_datagen.flow_from_dataframe(
        dataframe=df_train,
        directory=args.data_dir + TRAIN_DIR,
        x_col="id_code",
        y_col="diagnosis",
        batch_size=BATCH_SIZE,
        class_mode="categorical",
        target_size=(IMG_DIM, IMG_DIM),
        subset='validation',
        seed=SEED)

    # build model
    if args.arch == 'basic':
        model = network_builders.build_basic_model(args)
    elif args.arch == 'fc':
        model = network_builders.build_network_fc(args)
    elif args.arch == 'fc_cust':
        model = network_builders.build_fc_adjustable(args)
    elif args.arch == 'lenet':
        model = network_builders.build_lenet_conv(args)
    elif args.arch == 'allcnn':
        model = network_builders.build_all_cnn(args)
    elif args.arch == 'resnet':
        model = network_builders.build_resnet(args)
    elif args.arch == 'vgg':
        model = network_builders.build_vgg_half(args)
    else:
        raise Error("Unknown architeciture {}".format(args.arch))

    # get all_weights. Do it in 1 chunk if it fits into memory
    hf_weights = h5py.File(args.weights_h5, 'r')
    if args.stream_inputs:
        all_weights = hf_weights['all_weights']  # future work: switch to streamed dataset reads for speed
    else:
        all_weights = np.array(hf_weights['all_weights'], dtype='f8')
    shapes = [
        literal_eval(s)
        for s in hf_weights.attrs['var_shapes'].decode('utf-8').split(';')
    ]

    print(all_weights.shape)
    print(shapes)
    num_iters = min(args.max_iters, all_weights.shape[0] - 1)
    dim_sum = all_weights.shape[1]

    # set up output file
    output_name = args.output_h5
    if not output_name:  # use default gradients name
        assert args.weights_h5[-8:] == '/weights'
        output_name = args.weights_h5[:-8] + '/gradients_adaptive'
        if args.max_iters < all_weights.shape[0] - 1:
            output_name += '_{}iters'.format(args.max_iters)
    print('Writing gradients to file {}'.format(output_name))
    dsets = {}
    hf_grads = h5py.File(output_name, 'w-')
    dsets['trainloss'] = hf_grads.create_dataset('trainloss',
                                                 (num_iters + 1, ),
                                                 dtype='f4',
                                                 compression='gzip')
    dsets['testloss'] = hf_grads.create_dataset('testloss', (num_iters + 1, ),
                                                dtype='f4',
                                                compression='gzip')
    dsets['num_splits'] = hf_grads.create_dataset('num_splits', (num_iters, ),
                                                  dtype='i',
                                                  compression='gzip')

    pool = ThreadPool(args.num_gpus)
    iters_to_calc = divide_with_remainder(num_iters, args.num_gpus)
    results = []
    overall_timerstart = time.time()

    for gpu in range(args.num_gpus):
        # each process writes to a different variable in the file
        dsets['grads_train_{}'.format(gpu)] = hf_grads.create_dataset(
            'grads_train_{}'.format(gpu),
            (len(iters_to_calc[gpu]) * args.default_num_splits + 1, dim_sum),
            maxshape=(None, dim_sum),
            dtype='f4',
            compression='gzip')
        dsets['grads_test_{}'.format(gpu)] = hf_grads.create_dataset(
            'grads_test_{}'.format(gpu),
            (len(iters_to_calc[gpu]) * args.default_num_splits + 1, dim_sum),
            maxshape=(None, dim_sum),
            dtype='f4',
            compression='gzip')

        if args.num_gpus > 1:
            ret = pool.apply_async(
                run_thread,
                (gpu, iters_to_calc[gpu], all_weights, shapes, y_train_shape,
                 train_generator, y_test_shape, val_generator, dim_sum, args,
                 dsets, hf_grads))
            results.append(ret)
        else:
            run_thread(gpu, iters_to_calc[gpu], all_weights, shapes,
                       y_train_shape, train_generator, y_test_shape,
                       val_generator, dim_sum, args, dsets, hf_grads)

    pool.close()
    pool.join()
    print('return values: ', [res.get() for res in results])
    print('total time elapsed:', time.time() - overall_timerstart)
    hf_weights.close()
    hf_grads.close()
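
# divide_with_remainder is another project helper not shown here. A plausible
# sketch, assuming it splits the first num_iters iteration indices into
# num_gpus contiguous chunks, with earlier chunks absorbing the remainder:
def divide_with_remainder(total, num_parts):
    base, rem = divmod(total, num_parts)
    chunks, start = [], 0
    for i in range(num_parts):
        size = base + (1 if i < rem else 0)
        chunks.append(list(range(start, start + size)))
        start += size
    return chunks
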
Example #5
def main():
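    """Training variant for the diabetic-retinopathy data: builds Keras
    ImageDataGenerator flows from train.csv, trains through those generators,
    and optionally records weights/gradients to HDF5 as in the earlier
    examples."""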
    parser = make_parser()
    args = parser.parse_args()

    if args.tf_seed != -1:
        tf.random.set_random_seed(args.tf_seed)

    if not args.no_shuffle and args.shuffle_seed != -1:
        np.random.seed(args.shuffle_seed)

    #Define params for model:
    SEED = args.tf_seed
    BATCH_SIZE = args.train_batch_size
    CHANNEL_SIZE = args.input_dim[2]
    NUM_EPOCHS = args.num_epochs
    IMG_DIM = args.input_dim[0]
    TRAIN_DIR = 'train/'
    TEST_DIR = 'test/'
    CLASSES = {
        0: "No DR",
        1: "Mild",
        2: "Moderate",
        3: "Severe",
        4: "Proliferative DR"
    }

    df_train = pd.read_csv(os.path.join(args.data_dir, "train.csv"))
    df_test = pd.read_csv(os.path.join(args.data_dir, "test.csv"))

    print("Training set has {} samples".format(df_train.shape[0]))
    print("Testing set has {} samples".format(df_test.shape[0]))

    #Process image directories into exact file name (include .png):
    def append_ext(fn):
        return fn + ".png"

    df_train["id_code"] = df_train["id_code"].apply(append_ext)

    # load data into generator:
    # For some reason the generator wants diagnostic labels in string form:
    df_train['diagnosis'] = df_train['diagnosis'].astype(str)

    _validation_split = 0.20

    #x_train_shape = (int(np.round(df_train.shape[0] * (1 - _validation_split))), IMG_DIM, IMG_DIM, CHANNEL_SIZE)
    #x_test_shape = (int(np.round(df_train.shape[0] * _validation_split)), IMG_DIM, IMG_DIM, CHANNEL_SIZE)
    y_train_shape = (int(np.round(df_train.shape[0] *
                                  (1 - _validation_split))), None)
    y_test_shape = (int(np.round(df_train.shape[0] * _validation_split)), None)

    train_datagen = ImageDataGenerator(rescale=1. / 255,
                                       validation_split=_validation_split)

    train_generator = train_datagen.flow_from_dataframe(
        dataframe=df_train,
        directory=args.data_dir + TRAIN_DIR,
        x_col="id_code",
        y_col="diagnosis",
        batch_size=BATCH_SIZE,
        class_mode="categorical",
        target_size=(IMG_DIM, IMG_DIM),
        subset='training',
        seed=SEED)

    val_generator = train_datagen.flow_from_dataframe(
        dataframe=df_train,
        directory=args.data_dir + TRAIN_DIR,
        x_col="id_code",
        y_col="diagnosis",
        batch_size=BATCH_SIZE,
        class_mode="categorical",
        target_size=(IMG_DIM, IMG_DIM),
        subset='validation',
        seed=SEED)

    # build model
    if args.arch == 'basic':
        model = network_builders.build_basic_model(args)
    elif args.arch == 'fc':
        model = network_builders.build_network_fc(args)
    elif args.arch == 'fc_cust':
        model = network_builders.build_fc_adjustable(args)
    elif args.arch == 'lenet':
        model = network_builders.build_lenet_conv(args)
    elif args.arch == 'allcnn':
        model = network_builders.build_all_cnn(args)
    elif args.arch == 'resnet':
        model = network_builders.build_resnet(args)
    elif args.arch == 'vgg':
        model = network_builders.build_vgg_half(args)
    else:
        raise Error("Unknown architeciture {}".format(args.arch))

    init_model(model, args)
    define_training(model, args)

    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())

    for collection in ['tb_train_step']:  # 'eval_train' and 'eval_test' added manually later
        tf.summary.scalar(collection + '_acc', model.accuracy, collections=[collection])
        tf.summary.scalar(collection + '_loss', model.loss, collections=[collection])

    tb_writer, hf = None, None
    dsets = {}
    if args.output_dir:
        tb_writer = tf.summary.FileWriter(args.output_dir, sess.graph)
        # set up output for gradients/weights
        if args.save_weights:
            dim_sum = sum([
                tf.size(var).eval()
                for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            ])
            total_iters = args.num_epochs * int(
                y_train_shape[0] / args.train_batch_size)
            total_chunks = int(total_iters / args.save_every)
            hf = h5py.File(args.output_dir + '/weights', 'w-')

            # write metadata
            var_shapes = np.string_(';'.join([
                str(var.get_shape())
                for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            ]))
            hf.attrs['var_shapes'] = var_shapes
            var_names = np.string_(';'.join([
                str(var.name)
                for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            ]))
            hf.attrs['var_names'] = var_names

            # all individual weights at every iteration, where all_weights[i] = weights before iteration i:
            dsets['all_weights'] = hf.create_dataset(
                'all_weights', (total_chunks + 1, dim_sum),
                dtype='f8',
                compression='gzip')
        if args.save_training_grads:
            dsets['training_grads'] = hf.create_dataset(
                'training_grads', (total_chunks, dim_sum),
                dtype='f8',
                compression='gzip')

    train_and_eval(sess, model, y_train_shape, train_generator, y_test_shape,
                   val_generator, tb_writer, dsets, args)

    if tb_writer:
        tb_writer.close()
    if hf:
        hf.close()