Example #1
0
def run():
    """Train and evaluate ``MODEL`` on ``DATASET``, optionally with held-out OOD classes.

    All configuration is read from module-level globals (``DATASET``,
    ``N_TRAIN``, ``N_TEST``, ``TRAIN_NOISE``, ``TEST_NOISE``,
    ``TRAIN_TEST_RATIO``, ``N_OOD``, ``NORM``, ``MODEL``, ``MODEL_NAME``,
    ``MONITOR``, ``EPOCHS``, the mixup/generator settings and ``args``).

    Side effects:
        * creates ``./experiments_all/<MODEL_NAME>/<unix time>/`` holding the
          TensorBoard logs, the best-weights checkpoint (``ckpt``), a ``model``
          directory (saving into it is currently disabled) and ``args.txt``;
        * appends one comma-separated result line to
          ``./experiments_all/results.txt``.

    OOD handling runs only when a global ``N_OOD`` exists and is >= 1; the
    OOD-augmented test set is then evaluated as well and reported.
    """
    data = data_loader.load(DATASET,
                            n_train=N_TRAIN,
                            n_test=N_TEST,
                            train_noise=TRAIN_NOISE,
                            test_noise=TEST_NOISE)

    # These datasets are split without stratification (reason not visible here).
    stratify = DATASET not in ["abalone", "segment"]

    if DATASET not in [
            'arcene', 'moon', 'toy_Story', 'toy_Story_ood', 'segment'
    ]:
        # A single feature/label table: carve out the test set here.
        print(DATASET)
        x = data_loader.prepare_inputs(data['features'])
        y = data['labels']
        x_train, x_test, y_train, y_test = train_test_split(
            x,
            y,
            train_size=TRAIN_TEST_RATIO,
            stratify=y if stratify else None,
            random_state=0)

    else:
        # The loader already provides a split; its 'val' part serves as test set.
        if DATASET == 'moon' or DATASET == 'toy_Story' or DATASET == 'toy_Story_ood':
            x_train, x_test = data['x_train'], data['x_val']
        else:
            x_train, x_test = data_loader.prepare_inputs(
                data['x_train'], data['x_val'])
        y_train, y_test = data['y_train'], data['y_val']

    # Generate validation split
    x_train, x_val, y_train, y_val = train_test_split(
        x_train,
        y_train,
        train_size=TRAIN_TEST_RATIO,
        stratify=y_train if stratify else None,
        random_state=0)

    x_train = x_train.astype(np.float32)
    x_val = x_val.astype(np.float32)
    x_test = x_test.astype(np.float32)

    # Decide once whether OOD data is built. Every later OOD step keys off this
    # flag, so it never touches x_test_with_ood & co. when they do not exist
    # (N_OOD undefined, or 0 < N_OOD < 1).
    use_ood = 'N_OOD' in globals() and N_OOD >= 1
    if use_ood:
        n_ood = update_n_ood(data, DATASET, N_OOD)
        n_ood = y_val.shape[1] - n_ood - 1
        print(f"Number of ood classes: {n_ood}")
        x_train, x_val, x_test, y_train, y_val, y_test, x_ood, y_ood = prepare_ood(
            x_train, x_val, x_test, y_train, y_val, y_test, n_ood, NORM)
        # Split the OOD pool in half: one half joins validation, the other test.
        x_ood_val, x_ood_test, y_ood_val, y_ood_test = train_test_split(
            x_ood, y_ood, test_size=0.5, random_state=0)
        x_test_with_ood = np.concatenate([x_test, x_ood_test], axis=0)
        y_test_with_ood = np.concatenate([y_test, y_ood_test], axis=0)
        x_val_with_ood = np.concatenate([x_val, x_ood_val], axis=0)
        y_val_with_ood = np.concatenate([y_val, y_ood_val], axis=0)
    else:
        n_ood = 0
    print('Finish loading data')
    gdrive_rpath = './experiments_all'

    t = int(time.time())
    log_dir = os.path.join(gdrive_rpath, MODEL_NAME, '{}'.format(t))
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)
    file_writer_cm = tf.summary.create_file_writer(log_dir + '/cm')

    checkpoint_filepath = os.path.join(log_dir, 'ckpt')
    if not os.path.exists(checkpoint_filepath):
        os.makedirs(checkpoint_filepath)

    model_path = os.path.join(log_dir, 'model')
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    # Keep only the weights of the epoch with the highest MONITOR value.
    model_cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        save_weights_only=True,
        monitor=MONITOR,
        mode='max',
        save_best_only=True,
        verbose=1)

    model = build_model(x_train.shape[1], y_train.shape[1], MODEL, args)

    def plot_boundary(epoch, logs):
        """Log the model's decision surface over the grid [-10, 10)^2 to TensorBoard.

        Each grid point is coloured by the softmax-weighted (1-based) class index.
        NOTE(review): not registered as a callback in this function; the
        toy-dataset branch below uses ``cb.plot_boundary`` instead.
        """
        xy = np.mgrid[-10:10:0.1, -10:10:0.1].reshape(2, -1).T
        hat_z = tf.nn.softmax(model(xy, training=False), axis=1)
        c = np.sum(np.arange(hat_z.shape[1] + 1)[1:] * hat_z, axis=1)
        figure = plt.figure(figsize=(8, 8))
        plt.scatter(xy[:, 0], xy[:, 1], c=c, cmap="brg")
        image = plot_to_image(figure)
        with file_writer_cm.as_default():
            tf.summary.image("Boundaries", image, step=epoch)

    training_generator = mixup.data_generator(x_train,
                                              y_train,
                                              batch_size=BATCH_SIZE,
                                              n_channels=N_CHANNELS,
                                              shuffle=SHUFFLE,
                                              mixup_scheme=MIXUP_SCHEME,
                                              k=N_NEIGHBORS,
                                              alpha=ALPHA,
                                              local=LOCAL_RANDOM,
                                              out_of_class=OUT_OF_CLASS,
                                              manifold_mixup=MANIFOLD_MIXUP)

    # Validation/test generators emit the whole split as one unmixed batch.
    validation_generator = mixup.data_generator(x_val,
                                                y_val,
                                                batch_size=x_val.shape[0],
                                                n_channels=N_CHANNELS,
                                                shuffle=False,
                                                mixup_scheme='none',
                                                alpha=0,
                                                manifold_mixup=MANIFOLD_MIXUP)

    test_generator = mixup.data_generator(x_test,
                                          y_test,
                                          batch_size=x_test.shape[0],
                                          n_channels=N_CHANNELS,
                                          shuffle=True,
                                          mixup_scheme='none',
                                          alpha=0,
                                          manifold_mixup=MANIFOLD_MIXUP)

    if use_ood:
        in_out_test_generator = mixup.data_generator(
            x_test_with_ood,
            y_test_with_ood,
            batch_size=x_test_with_ood.shape[0],
            n_channels=N_CHANNELS,
            shuffle=True,
            mixup_scheme='none',
            alpha=0,
            manifold_mixup=MANIFOLD_MIXUP)

    callbacks = [tensorboard_callback, model_cp_callback]
    # Case-insensitive so it matches the 'toy_Story' spelling used above; the
    # former 'Toy_story' literal never matched any dataset name.
    if DATASET.lower() in ('toy_story', 'toy_story_ood'):
        border_callback = tf.keras.callbacks.LambdaCallback(
            on_epoch_end=cb.plot_boundary)
        callbacks += [border_callback]
    if MODEL in ['jem', 'jemo', 'jehm', 'jehmo', 'jehmo_mix']:
        callbacks += [cb.jem_n_epochs()]

    ## buffer ##
    # Disabled: for MODEL in ['jehmo', 'jehmo_mix'] with model.with_buffer_out,
    # model.replay_buffer_out = get_buffer(model.buffer_size,
    #                                      training_generator.x.shape[1],
    #                                      x=training_generator.x)

    ## training ##
    t_train_start = int(time.time())
    training_history = model.fit(x=training_generator,
                                 validation_data=validation_generator,
                                 epochs=EPOCHS,
                                 callbacks=callbacks)
    t_train_end = int(time.time())
    used_time = t_train_end - t_train_start
    # Restore the best checkpoint before evaluating.
    model.load_weights(checkpoint_filepath)
    # model.save(model_path)
    print('Tensorboard callback directory: {}'.format(log_dir))

    metric_file = os.path.join(gdrive_rpath, 'results.txt')
    loss = model.evaluate(test_generator, return_dict=True)
    if use_ood:
        ood_loss = model.evaluate(in_out_test_generator, return_dict=True)
        with open(metric_file, "a+") as f:
            f.write(f"{MODEL}, {MIXUP_SCHEME}, {DATASET}, {t}, {loss['accuracy']:.3f}," \
                    f"{loss['ece_metrics']:.3f}, {loss['oe_metrics']:.3f}," \
                    f"{ood_loss['accuracy']:.3f}," \
                    f"{ood_loss['ece_metrics']:.3f}, {ood_loss['oe_metrics']:.3f},"
                    f"{n_ood}, {ood_loss['auc_of_ood']}, {used_time}\n")
    else:
        with open(metric_file, "a+") as f:
            f.write(f"{MODEL}, {MIXUP_SCHEME}, {DATASET}, {t}, {loss['accuracy']:.3f}," \
                    f"{loss['ece_metrics']:.3f}, {loss['oe_metrics']:.3f}," \
                    f"None, " \
                    f"None, None,"
                    f"{n_ood}, None, {used_time}\n")

    arg_file = os.path.join(log_dir, 'args.txt')
    with open(arg_file, "w+") as f:
        f.write(str(args))
Example #2
0
def run():
    """Train ``MODEL`` on in-distribution classes and checkpoint on validation OOD AUC.

    All configuration is read from module-level globals (``DATASET``, ``OOD``,
    ``N_OOD``, ``TRAIN_TEST_RATIO``, ``MODEL``, ``MODEL_NAME``, ``EPOCHS``,
    the mixup/generator settings and ``args``).

    When a global ``N_OOD`` exists and is >= 1, ``prepare_ood_from_args``
    decides how many of the highest-index classes are treated as OOD. Their
    training samples are moved into the validation set with their label
    columns truncated to the in-distribution classes. Test labels are
    truncated the same way.

    Side effects: writes TensorBoard logs, a best-weights checkpoint, the saved
    model and ``results.txt`` under ``./experiments_ood/<MODEL_NAME>/<unix time>/``.
    """
    data = data_loader.load(DATASET,
                            n_train=N_TRAIN,
                            n_test=N_TEST,
                            train_noise=TRAIN_NOISE,
                            test_noise=TEST_NOISE,
                            ood=OOD)

    # These datasets are split without stratification (reason not visible here).
    stratify = DATASET not in ["abalone", "segment"]

    if DATASET not in [
            'arcene', 'moon', 'toy_Story', 'toy_Story_ood', 'segment'
    ]:
        # A single feature/label table: carve out the test set here.
        print(DATASET)
        x = data_loader.prepare_inputs(data['features'])
        y = data['labels']
        x_train, x_test, y_train, y_test = train_test_split(
            x,
            y,
            train_size=TRAIN_TEST_RATIO,
            stratify=y if stratify else None)

    else:
        # The loader already provides a split; its 'val' part serves as test set.
        if DATASET == 'moon' or DATASET == 'toy_Story' or DATASET == 'toy_Story_ood':
            x_train, x_test = data['x_train'], data['x_val']
        else:
            x_train, x_test = data_loader.prepare_inputs(
                data['x_train'], data['x_val'])
        y_train, y_test = data['y_train'], data['y_val']

    if 'N_OOD' in globals() and N_OOD >= 1:
        n_ood = prepare_ood_from_args(data, DATASET, N_OOD)
        n_in = y_train.shape[1] - n_ood

        # Classes with index >= n_in are treated as OOD.
        train_classes = np.argmax(y_train, axis=1)
        train_in_idxs = train_classes < n_in
        train_ood_idxs = train_classes >= n_in
        x_train_in = x_train[train_in_idxs]
        y_train_in = y_train[train_in_idxs][:, 0:n_in]
        x_train_out = x_train[train_ood_idxs]
        # Only the in-distribution columns are kept (all zeros for OOD rows,
        # assuming one-hot labels).
        y_train_out = y_train[train_ood_idxs][:, 0:n_in]

        # Generate validation split from the in-distribution data only.
        x_train_in, x_val_in, y_train_in, y_val_in = train_test_split(
            x_train_in,
            y_train_in,
            train_size=TRAIN_TEST_RATIO,
            stratify=y_train_in if stratify else None)

        # Validation = all OOD training samples + held-out in-distribution ones.
        x_val = np.concatenate((x_train_out, x_val_in), axis=0)
        y_val = np.concatenate((y_train_out, y_val_in), axis=0)
        y_test = y_test[:, 0:n_in]
        y_val = y_val[:, 0:n_in]

        # Train on in-distribution data only. y_train must follow x_train so
        # that build_model sizes its output for n_in classes, matching the labels.
        x_train, y_train = x_train_in, y_train_in
    else:
        x_train, x_val, y_train, y_val = train_test_split(
            x_train,
            y_train,
            train_size=TRAIN_TEST_RATIO,
            stratify=y_train if stratify else None)

    x_train = x_train.astype(np.float32)
    x_val = x_val.astype(np.float32)
    x_test = x_test.astype(np.float32)

    print('Finish loading data')
    gdrive_rpath = './experiments_ood'

    t = int(time.time())
    log_dir = os.path.join(gdrive_rpath, MODEL_NAME, '{}/logs'.format(t))
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)
    file_writer_cm = tf.summary.create_file_writer(log_dir + '/cm')

    checkpoint_filepath = os.path.join(gdrive_rpath, MODEL_NAME,
                                       '{}/ckpt/'.format(t))
    if not os.path.exists(checkpoint_filepath):
        os.makedirs(checkpoint_filepath)

    model_path = os.path.join(gdrive_rpath, MODEL_NAME, '{}/model'.format(t))
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    # Keep only the weights of the epoch with the best validation OOD AUC.
    model_cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        save_weights_only=True,
        monitor='val_auc_of_ood',
        mode='max',
        save_best_only=True)

    model = build_model(x_train.shape[1], y_train.shape[1], MODEL, args)

    def plot_boundary(epoch, logs):
        """Log the model's decision surface over the grid [-10, 10)^2 to TensorBoard."""
        xy = np.mgrid[-10:10:0.1, -10:10:0.1].reshape(2, -1).T
        hat_z = tf.nn.softmax(model(xy, training=False), axis=1)
        # Colour = softmax-weighted 1-based class index.
        c = np.sum(np.arange(hat_z.shape[1] + 1)[1:] * hat_z, axis=1)
        figure = plt.figure(figsize=(8, 8))
        plt.scatter(xy[:, 0], xy[:, 1], c=c, cmap="brg")
        image = plot_to_image(figure)
        with file_writer_cm.as_default():
            tf.summary.image("Boundaries", image, step=epoch)

    def plot_boundary_pretrain(epoch, logs):
        """Log the decision surface over [-1, 1.1) x [-2, 2.1) to TensorBoard."""
        xy = np.mgrid[-1:1.1:0.01, -2:2.1:0.01].reshape(2, -1).T
        hat_z = tf.nn.softmax(model(xy, training=False), axis=1)
        # Weights follow the model's class count instead of a hard-coded 5.
        c = np.sum(np.arange(1, hat_z.shape[1] + 1) * hat_z, axis=1)
        figure = plt.figure(figsize=(8, 8))
        plt.scatter(xy[:, 0], xy[:, 1], c=c, cmap="brg")
        image = plot_to_image(figure)
        with file_writer_cm.as_default():
            tf.summary.image("Boundaries_pretrain", image, step=epoch)

    # Currently unused (see the disabled pretraining / border callback below).
    border_callback_pretrain = tf.keras.callbacks.LambdaCallback(
        on_epoch_end=plot_boundary_pretrain)
    border_callback = tf.keras.callbacks.LambdaCallback(
        on_epoch_end=plot_boundary)

    training_generator = mixup.data_generator(x_train,
                                              y_train,
                                              batch_size=BATCH_SIZE,
                                              n_channels=N_CHANNELS,
                                              shuffle=SHUFFLE,
                                              mixup_scheme=MIXUP_SCHEME,
                                              k=N_NEIGHBORS,
                                              alpha=ALPHA,
                                              local=LOCAL_RANDOM,
                                              out_of_class=OUT_OF_CLASS,
                                              manifold_mixup=MANIFOLD_MIXUP)

    # Validation/test generators emit the whole split as one unmixed batch.
    validation_generator = mixup.data_generator(x_val,
                                                y_val,
                                                batch_size=x_val.shape[0],
                                                n_channels=N_CHANNELS,
                                                shuffle=False,
                                                mixup_scheme='none',
                                                alpha=0,
                                                manifold_mixup=MANIFOLD_MIXUP)

    test_generator = mixup.data_generator(x_test,
                                          y_test,
                                          batch_size=x_test.shape[0],
                                          n_channels=N_CHANNELS,
                                          shuffle=False,
                                          mixup_scheme='none',
                                          alpha=0,
                                          manifold_mixup=MANIFOLD_MIXUP)

    # Pretraining (disabled):
    # if DATASET=='toy_Story':
    #   pre_x = np.mgrid[-1:1.1:0.01, -2:2.1:0.01].reshape(2,-1).T
    #   pre_y = .2*np.ones(shape=[pre_x.shape[0], 5])
    #   model.fit(x=pre_x, y=pre_y, epochs=1, callbacks=[border_callback_pretrain])

    training_history = model.fit(
        x=training_generator,
        validation_data=validation_generator,
        epochs=EPOCHS,
        callbacks=[
            tensorboard_callback,
            model_cp_callback,
            # border_callback
        ],
    )

    # summary() prints itself; wrapping it in print() only added "None".
    model.summary()
    # Restore the best checkpoint before saving/evaluating.
    model.load_weights(checkpoint_filepath)
    model.save(model_path)
    print('Tensorboard callback directory: {}'.format(log_dir))

    metric_file = os.path.join(gdrive_rpath, MODEL_NAME,
                               '{}/results.txt'.format(t))
    loss = model.evaluate(test_generator, return_dict=True)
    with open(metric_file, "w") as f:
        f.write(str(loss))
def run():
    """Train ``MODEL`` on standardized ``DATASET`` features and plot confidence histograms.

    All configuration is read from module-level globals (``DATASET``,
    ``N_OOD``, ``TRAIN_TEST_RATIO``, ``MODEL``, ``MODEL_NAME``, ``MONITOR``,
    ``EPOCHS``, the mixup/generator settings and ``args``).

    When ``N_OOD > 0`` and the dataset has more than ``N_OOD`` classes, the
    ``N_OOD`` highest-index classes are removed from every split and kept
    apart as OOD samples. Their max-softmax confidences are plotted against
    those of the in-distribution test+validation samples.

    Side effects: writes TensorBoard logs, a best-weights checkpoint,
    ``confidence.png``, ``acc_conf.png`` and ``args.txt`` under
    ``./experiments/<MODEL_NAME>/<unix time>/``, and appends one result line
    to ``./experiments/results.txt``.
    """
    data = data_loader.load(DATASET,
                            n_train=N_TRAIN,
                            n_test=N_TEST,
                            train_noise=TRAIN_NOISE,
                            test_noise=TEST_NOISE)

    # These datasets are split without stratification (reason not visible here).
    stratify = DATASET not in ["abalone", "segment"]

    if DATASET not in [
            'arcene', 'moon', 'toy_story', 'toy_story_ood', 'segment'
    ]:
        x = data_loader.prepare_inputs(data['features'])
        y = data['labels']
        x_train, x_test, y_train, y_test = train_test_split(
            x,
            y,
            train_size=TRAIN_TEST_RATIO,
            stratify=y if stratify else None)
    else:
        # The loader already provides a split; its 'val' part serves as test set.
        if DATASET == 'moon' or DATASET == 'toy_story' or DATASET == 'toy_story_ood':
            x_train, x_test = data['x_train'], data['x_val']
        else:
            x_train, x_test = data_loader.prepare_inputs(
                data['x_train'], data['x_val'])
        y_train, y_test = data['y_train'], data['y_val']

    # generate validation split
    x_train, x_val, y_train, y_val = train_test_split(
        x_train,
        y_train,
        train_size=TRAIN_TEST_RATIO,
        stratify=y_train if stratify else None)

    x_train = x_train.astype(np.float32)
    x_val = x_val.astype(np.float32)
    x_test = x_test.astype(np.float32)

    # Standardize with training-set statistics (delete for categorical datasets).
    n_mean = np.mean(x_train, axis=0)
    n_std = np.var(x_train, axis=0)**.5
    # Constant features have zero spread; divide them by 1 so they become 0
    # instead of NaN/inf.
    n_std[n_std == 0] = 1.0

    x_train = (x_train - n_mean) / n_std
    x_val = (x_val - n_mean) / n_std
    x_test = (x_test - n_mean) / n_std

    n_classes = y_val.shape[1]
    has_ood = N_OOD > 0 and n_classes > N_OOD
    if has_ood:
        # Classes 0..last_in stay in-distribution; the top N_OOD become OOD.
        last_in = n_classes - N_OOD - 1
        train_classes = np.argmax(y_train, axis=1)
        test_classes = np.argmax(y_test, axis=1)
        val_classes = np.argmax(y_val, axis=1)
        idx_train_ood = train_classes > last_in
        idx_train_in = train_classes <= last_in
        idx_test_ood = test_classes > last_in
        idx_test_in = test_classes <= last_in
        idx_val_ood = val_classes > last_in
        idx_val_in = val_classes <= last_in

        # Label slices select class COLUMNS; slicing rows would desync x and y.
        x_test_ood = x_test[idx_test_ood]
        y_test_ood = y_test[idx_test_ood][:, last_in + 1:]
        x_train_ood = x_train[idx_train_ood]
        y_train_ood = y_train[idx_train_ood][:, last_in + 1:]
        x_val_ood = x_val[idx_val_ood]
        y_val_ood = y_val[idx_val_ood][:, last_in + 1:]

        x_train = x_train[idx_train_in]
        x_test = x_test[idx_test_in]
        x_val = x_val[idx_val_in]
        y_train = y_train[idx_train_in][:, :last_in + 1]
        y_test = y_test[idx_test_in][:, :last_in + 1]
        y_val = y_val[idx_val_in][:, :last_in + 1]

    print('Finish loading data')
    gdrive_rpath = './experiments'

    t = int(time.time())
    log_dir = os.path.join(gdrive_rpath, MODEL_NAME, '{}'.format(t))
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)
    file_writer_cm = tf.summary.create_file_writer(log_dir + '/cm')

    checkpoint_filepath = os.path.join(log_dir, 'ckpt')
    if not os.path.exists(checkpoint_filepath):
        os.makedirs(checkpoint_filepath)

    model_path = os.path.join(log_dir, 'model')
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    # Keep only the weights of the epoch with the highest MONITOR value.
    model_cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        save_weights_only=True,
        monitor=MONITOR,
        mode='max',
        save_best_only=True)

    model = build_model(x_train.shape[1], y_train.shape[1], MODEL, args)

    training_generator = mixup.data_generator(x_train,
                                              y_train,
                                              batch_size=BATCH_SIZE,
                                              n_channels=N_CHANNELS,
                                              shuffle=SHUFFLE,
                                              mixup_scheme=MIXUP_SCHEME,
                                              k=N_NEIGHBORS,
                                              alpha=ALPHA,
                                              local=LOCAL_RANDOM,
                                              out_of_class=OUT_OF_CLASS,
                                              manifold_mixup=MANIFOLD_MIXUP)

    # Validation/test generators emit the whole split as one unmixed batch.
    validation_generator = mixup.data_generator(x_val,
                                                y_val,
                                                batch_size=x_val.shape[0],
                                                n_channels=N_CHANNELS,
                                                shuffle=False,
                                                mixup_scheme='none',
                                                alpha=0,
                                                manifold_mixup=MANIFOLD_MIXUP)

    test_generator = mixup.data_generator(x_test,
                                          y_test,
                                          batch_size=x_test.shape[0],
                                          n_channels=N_CHANNELS,
                                          shuffle=False,
                                          mixup_scheme='none',
                                          alpha=0,
                                          manifold_mixup=MANIFOLD_MIXUP)

    callbacks = [tensorboard_callback, model_cp_callback]
    # Case-insensitive so it matches the 'toy_story' spelling used above; the
    # former 'Toy_story' literal never matched any dataset name.
    if DATASET.lower() in ('toy_story', 'toy_story_ood'):
        border_callback = tf.keras.callbacks.LambdaCallback(
            on_epoch_end=cb.plot_boundary)
        callbacks += [border_callback]
    if MODEL == 'jem':
        callbacks += [cb.jem_n_epochs()]

    training_history = model.fit(x=training_generator,
                                 validation_data=validation_generator,
                                 epochs=EPOCHS,
                                 callbacks=callbacks)

    # Restore the best checkpoint before evaluating.
    model.load_weights(checkpoint_filepath)
    #model.save(model_path)
    print('Tensorboard callback directory: {}'.format(log_dir))

    metric_file = os.path.join(gdrive_rpath, 'results.txt')
    loss = model.evaluate(test_generator, return_dict=True)

    # In-distribution evaluation on test + validation samples. Predictions are
    # compared with the matching label rows (argmax of one-hot), not with the
    # raw y_test matrix.
    x_in = np.concatenate([x_test, x_val], axis=0)
    y_in = np.concatenate([y_test, y_val], axis=0)
    z_in = tf.nn.softmax(model(x_in))
    c_in = tf.math.reduce_max(z_in, axis=-1)
    correct_in = tf.cast(
        tf.equal(tf.math.argmax(z_in, axis=-1), tf.math.argmax(y_in, axis=-1)),
        tf.float32)
    acc_in = tf.reduce_mean(correct_in)

    # Plot histogram from confidences (OOD samples only exist if split above).
    plt.hist(c_in, bins=20, color='blue', label='In')
    if has_ood:
        z_out = tf.nn.softmax(
            model(np.concatenate([x_train_ood, x_test_ood, x_val_ood], axis=0)))
        c_out = tf.math.reduce_max(z_out, axis=-1)
        plt.hist(c_out, bins=20, color='red', label='Out')
    plt.ylabel('Frequency')
    plt.xlabel('Confidence')
    plt.xlim([0, 1])
    plt.legend()
    plt.savefig(os.path.join(log_dir, 'confidence.png'), dpi=300)
    plt.close()
    plt.hist(c_in, density=True, bins=20, color='blue', label='Confidence')
    # acc_in is a scalar; wrap it so hist receives a 1-element sequence.
    plt.hist([float(acc_in)], density=True, bins=20, color='red', label='Accuracy')
    plt.ylabel('Fraction')
    plt.xlabel('Confidence')
    plt.xlim([0, 1])
    plt.legend()
    plt.savefig(os.path.join(log_dir, 'acc_conf.png'), dpi=300)
    plt.close()
    with open(metric_file, "a+") as f:
        f.write(f"{MODEL}, {DATASET}, {t}, {loss['accuracy']:.3f}," \
                f"{loss['ece_metrics']:.3f}, {loss['oe_metrics']:.3f}," \
                f"{loss['loss']:.3f}\n")

    arg_file = os.path.join(log_dir, 'args.txt')
    with open(arg_file, "w+") as f:
        f.write(str(args))