Exemplo n.º 1
0
                     'The width of the policy conv layer.')

# Value head sizing. The fully connected width depends on go.N (presumably
# the board size): 19 gets the wider layer.
flags.DEFINE_integer(
    'value_conv_width',
    default=1,
    help='The width of the value conv layer.')

flags.DEFINE_integer(
    'fc_width',
    default=(256 if go.N == 19 else 64),
    help='The width of the fully connected layer in value head.')

flags.DEFINE_integer(
    'trunk_layers',
    default=go.N,
    help='The number of resnet layers in the shared trunk.')

# Learning-rate schedule: the validator below requires exactly one more rate
# than there are boundaries.
flags.DEFINE_multi_integer(
    'lr_boundaries',
    default=[400000, 600000],
    help='The number of steps at which the learning rate will decay')

flags.DEFINE_multi_float(
    'lr_rates',
    default=[0.01, 0.001, 0.0001],
    help='The different learning rates')

flags.register_multi_flags_validator(
    ['lr_boundaries', 'lr_rates'],
    # The checker receives a {flag_name: value} dict. Its parameter is not
    # named `flags`, so it does not shadow the absl module.
    lambda flag_values: (
        len(flag_values['lr_rates']) - 1 == len(flag_values['lr_boundaries'])),
    message=('Number of learning rates must be exactly one greater than '
             'the number of boundaries'))

flags.DEFINE_float(
    'l2_strength',
    default=1e-4,
    help='The L2 regularization parameter applied to weights.')

flags.DEFINE_float(
    'value_cost_weight',
    default=1.0,
    help=('Scalar for value_cost, AGZ paper suggests 1/100 for '
          'supervised learning'))
Exemplo n.º 2
0
    'final batch size will be = train_batch_size * num_tpu_cores')

# Network shape. Widths depend on go.N (presumably the board size).
flags.DEFINE_integer('conv_width', 128 if go.N == 19 else 32,
                     'The width of each conv layer in the shared trunk.')

flags.DEFINE_integer('fc_width', 256 if go.N == 19 else 64,
                     'The width of the fully connected layer in value head.')

flags.DEFINE_integer('trunk_layers', go.N,
                     'The number of resnet layers in the shared trunk.')

# Learning-rate schedule: N boundaries need N + 1 rates.
flags.DEFINE_multi_integer(
    'lr_boundaries', [10000, 30000],
    'The number of steps at which the learning rate will decay')

flags.DEFINE_multi_float('lr_rates', [2e-2, 1e-3, 1e-4],
                         'The different learning rates')

# Reject a mismatched schedule when flags are parsed, instead of letting it
# reach the training code. The checker receives a {flag_name: value} dict.
flags.register_multi_flags_validator(
    ['lr_boundaries', 'lr_rates'],
    lambda flag_values: (
        len(flag_values['lr_boundaries']) == len(flag_values['lr_rates']) - 1),
    message=('Number of learning rates must be exactly one greater than '
             'the number of boundaries'))

flags.DEFINE_float('l2_strength', 1e-4,
                   'The L2 regularization parameter applied to weights.')

# Momentum is an optimizer hyperparameter, not a property of the learning
# rate; the help text now says so.
flags.DEFINE_float('sgd_momentum', 0.9,
                   'Momentum parameter for SGD.')

flags.DEFINE_string('model_dir', None, 'The working directory of the model')

# See www.moderndescartes.com/essays/shuffle_viz for discussion on sizing
flags.DEFINE_integer('shuffle_buffer_size', 20000,
                     'Size of buffer used to shuffle train examples.')

flags.DEFINE_bool('use_tpu', False, 'Whether to use TPU for training.')
Exemplo n.º 3
0
from absl import app, flags
from mlperf_logging import mllog
import os

# Every flag below defaults to None, so a value has to be supplied on the
# command line (or in a flagfile) for it to be set.
# NOTE(review): 'lr_boundaries' is a multi_float and 'bool_features' an
# integer here; confirm these types match how the values are produced.
flags.DEFINE_multi_float('lr_rates', default=None, help='lr rates')
flags.DEFINE_multi_float('lr_boundaries', default=None,
                         help='learning rate boundaries')
flags.DEFINE_float('l2_strength', default=None, help='weight decay')
flags.DEFINE_integer('conv_width', default=None, help='conv width')
flags.DEFINE_integer('fc_width', default=None, help='fc width')
flags.DEFINE_integer('trunk_layers', default=None, help='trunk layers')
flags.DEFINE_float('value_cost_weight', default=None,
                   help='value cost weight')
flags.DEFINE_integer('summary_steps', default=None, help='summary steps')
flags.DEFINE_integer('bool_features', default=None, help='bool features')
flags.DEFINE_string('input_features', default=None, help='input features')
flags.DEFINE_string('input_layout', default=None, help='input layout')
flags.DEFINE_integer('shuffle_buffer_size', default=None,
                     help='shuffle buffer size')
flags.DEFINE_boolean('shuffle_examples', default=None,
                     help='shuffle examples')
flags.DEFINE_integer('keep_checkpoint_max', default=None,
                     help='keep_checkpoint_max')
flags.DEFINE_integer('train_batch_size', default=None,
                     help='train_batch_size')

FLAGS = flags.FLAGS


def main(argv):
    """Set up MLPerf logging (mlperf_logging.mllog) with output to train.log.

    Args:
        argv: Presumably the argument list passed in by `absl.app.run`;
            it is not used in the lines shown here.
    """
    # NOTE(review): `mllogger` is not used in the lines shown; presumably it
    # is used further down (e.g. to log FLAGS values) — confirm.
    mllogger = mllog.get_mllogger()
    # Direct log output to "train.log" (presumably relative to the current
    # working directory).
    mllog.config(filename="train.log")

    # Defaults for subsequent log calls: namespace "worker1", a stack
    # offset of 1, and default_clear_line disabled.
    mllog.config(default_namespace="worker1",
                 default_stack_offset=1,
                 default_clear_line=False)
                     help='use fp16 compression during allreduce')
# Distributed (allreduce) training and logging cadence.
flags.DEFINE_integer(
    'BATCHES_PER_ALLREDUCE', 1,
    help=('number of batches processed locally before '
          'executing allreduce across workers; it multiplies '
          'total batch size.'))
flags.DEFINE_boolean(
    'USE_ADASUM', False,
    help='use adasum algorithm to do reduction')
flags.DEFINE_integer(
    'LOG_INTERVAL', 10,
    help='how many batches to wait before logging training status')

# Default settings from https://arxiv.org/abs/1706.02677.
flags.DEFINE_string(
    'MODEL_NAME', default='MobileNetV2',
    help='The name of the architecture to train.')
flags.DEFINE_string(
    'DATASET_NAME', default='CIFAR100',
    help='The name of the dataset to train.')
# Normalisation statistics: presumably one entry per channel, matching the
# 3 in DATA_SHAPE.
flags.DEFINE_multi_float(
    'DATA_MEAN', default=[0.5071, 0.4867, 0.4408],
    help='mean value of dataset')
flags.DEFINE_multi_float(
    'DATA_STD', default=[0.2675, 0.2565, 0.2761],
    help='standard deviation value of dataset')
flags.DEFINE_multi_integer(
    'DATA_SHAPE', default=[3, 32, 32],
    help='data dimension of dataset')
flags.DEFINE_list(
    'BLOCK_ARGS',
    default=[
        'wm1.0_rn8_s1',
        't1_c16_n1_s1',
        't6_c24_n2_s1',
        't6_c32_n3_s2',
        't6_c64_n4_s2',
        't6_c96_n3_s1',
        't6_c160_n3_s2',
        't6_c320_n1_s1',
    ],
    help='argument of blocks in EfficientNet style')
flags.DEFINE_integer(
    'BATCH_SIZE', 128,
    help='input batch size for training')
Exemplo n.º 5
0
flags.DEFINE_boolean(
    'load_full_weights', False,
    'Load full COCO pretrained weights including those of the last detection layers'
)

flags.DEFINE_multi_integer(
    'model_size', (608, 608),
    'Resolution of DNN input, must be the multiples of 32')
# The help text requires multiples of 32. Enforce that when flags are parsed
# so a bad resolution is reported up front.
flags.register_validator(
    'model_size',
    lambda sizes: all(size % 32 == 0 for size in sizes),
    message='--model_size values must be multiples of 32')
flags.DEFINE_integer('max_out_size', 1,
                     'maximum detected object amount of one class')
flags.DEFINE_float('iou_threshold', 0.5, 'threshold of non-max suppression')
flags.DEFINE_float('confid_threshold', 0.5, 'threshold of confidence')

flags.DEFINE_float('brightness_delta', 0.3,
                   'brightness_delta of data augmentation')
flags.DEFINE_multi_float('contrast_range', (0.5, 1.5),
                         'contrast_range of data augmentation')
# These bounds implement the range stated in the help string. absl bounds are
# inclusive, so both 0 and 0.5 are accepted.
flags.DEFINE_float('hue_delta', 0.2,
                   'hue_delta of data augmentation, only between (0, 0.5)',
                   lower_bound=0.0, upper_bound=0.5)
flags.DEFINE_float('probability', 0.8, 'percentage of augmented images')

# Anchors of k-means threshold=0.98
_ANCHORS = [(77.0, 91.0), (89.0, 93.0), (83.0, 101.0), (95.0, 102.0),
            (92.0, 111.0), (104.0, 108.0), (98.0, 117.0), (110.0, 122.0),
            (127.0, 134.0)]

# Anchors of k-means threshold=0.99
# _ANCHORS = [(77.0, 92.0), (89.0, 93.0), (83.0, 101.0), (95.0, 101.0), (93.0, 111.0), (105.0, 109.0), (101.0, 119.0), (114.0, 125.0), (151.0, 152.0)]

# Default anchors of COCO
# _ANCHORS = [(10,13),(16,30),(33,23),(30,61),(62,45),(59,119),(116,90),(156,198),(373,326)]
Exemplo n.º 6
0
# Maps each accepted --loss value to the loss implementation it selects.
loss_fns = dict(
    bce=losses.BCEWithLogits,
    hinge=losses.Hinge,
    was=losses.Wasserstein,
    softplus=losses.Softplus,
)

FLAGS = flags.FLAGS

# --- Model and training -----------------------------------------------------
flags.DEFINE_enum('dataset', 'cifar10', ['cifar10', 'stl10'], help="dataset")
flags.DEFINE_enum('arch', 'cnn32', net_G_models.keys(), help="architecture")
flags.DEFINE_integer('total_steps', 100000,
                     help="total number of training steps")
flags.DEFINE_integer('batch_size', 128, help="batch size")
flags.DEFINE_float('lr_G', 2e-4, help="Generator learning rate")
flags.DEFINE_float('lr_D', 2e-4, help="Discriminator learning rate")
flags.DEFINE_multi_float('betas', [0.0, 0.9], help="for Adam")
flags.DEFINE_integer('n_dis', 5, help="update Generator every this steps")
flags.DEFINE_integer('z_dim', 128, help="latent space dimension")
flags.DEFINE_float('c', 0.1, help="clip value")
flags.DEFINE_enum('loss', 'was', loss_fns.keys(), help="loss function")
flags.DEFINE_integer('seed', 0, help="random seed")

# --- Logging ----------------------------------------------------------------
flags.DEFINE_integer('eval_step', 5000,
                     help="evaluate FID and Inception Score")
flags.DEFINE_integer('sample_step', 500,
                     help="sample image every this steps")
flags.DEFINE_integer('sample_size', 64, help="sampling size of images")
flags.DEFINE_string('logdir', './logs/WGAN_CIFAR10_CNN',
                    help='logging folder')
flags.DEFINE_bool('record', True, help="record inception score and FID")
flags.DEFINE_string('fid_cache', './stats/cifar10.train.npz',
                    help='FID cache')

# --- Generation -------------------------------------------------------------
flags.DEFINE_bool('generate', False, help='generate images')
flags.DEFINE_string('pretrain', None, help='path to test model')
Exemplo n.º 7
0
                  'Image crop size [height, width] for evaluation.')

flags.DEFINE_integer(
    'eval_interval_secs',
    default=5 * 60,
    help='How often (in seconds) to run evaluation.')

# Suggested atrous_rates: for `xception_65`, [12, 24, 36] with
# output_stride = 8 or [6, 12, 18] with output_stride = 16; for
# `mobilenet_v2`, leave it as None. Training and evaluation may use
# different atrous_rates/output_stride settings.
flags.DEFINE_multi_integer(
    'atrous_rates',
    default=None,
    help='Atrous rates for atrous spatial pyramid pooling.')

flags.DEFINE_integer(
    'output_stride',
    default=16,
    help='The ratio of input to output spatial resolution.')

# Use [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] for a multi-scale test.
flags.DEFINE_multi_float(
    'eval_scales',
    default=[1.0],
    help='The scales to resize images for evaluation.')

# Set to True to also add flipped images during test.
flags.DEFINE_bool(
    'add_flipped_images',
    default=False,
    help='Add flipped images for evaluation or not.')

# A negative value disables quantization.
flags.DEFINE_integer(
    'quantize_delay_step',
    default=-1,
    help='Steps to start quantized training. If < 0, will not quantize model.')

# Dataset settings.

flags.DEFINE_string(
    'dataset',
    default='pascal_voc_seg',
    help='Name of the segmentation dataset.')

flags.DEFINE_string('eval_split', 'val',
Exemplo n.º 8
0
from absl import app, flags
from time import gmtime, strftime
import json

FLAGS = flags.FLAGS
flags.DEFINE_integer('N', 50, 'number of images')
flags.DEFINE_integer('seed', 42,
                     'random seed for sampling images from ImageNet')
flags.DEFINE_multi_integer('attack_size', [20, 20], 'size of sticker')
flags.DEFINE_integer('stride', 20, 'stride of sticker')
flags.DEFINE_integer('batchsize', 128, 'batch size')
flags.DEFINE_string('model', 'bagnet33', 'model being evaluated')
flags.DEFINE_string('clip_fn', 'tanh_linear', 'clipping function')
# Only 'a' or 'b' are meaningful (the parameter in clip(a*x + b)). An enum
# makes absl reject any other value when the command line is parsed.
flags.DEFINE_enum('param', 'a', ['a', 'b'],
                  'clip(a*x + b). Which parameter are going to be tested')
flags.DEFINE_multi_float('param_list', None, 'list of parameters to be tested')
flags.DEFINE_float('fixed_param', None, 'the other fixed parameter')
flags.DEFINE_string('data_path', '/mnt/data/imagenet',
                    'directory where data are stored')
flags.DEFINE_string('output_root', '/mnt/data/clipping_params_searching/',
                    'directory for storing results')


def main(argv):
    """Entry point; rejects an invalid --param before anything else runs.

    Output layout under FLAGS.output_root::

        FLAGS.output_root/
            [NAME]/
                [NAME].log
                [Name].lst

    Raises:
        app.UsageError: if --param is neither 'a' nor 'b'. absl's app.run
            reports this as a usage error rather than a traceback.
    """
    # Explicit check instead of `assert`: asserts are stripped when Python
    # runs with -O, which would silently skip this validation.
    if FLAGS.param not in ('a', 'b'):
        raise app.UsageError('FLAGS.param must be either a or b')
Exemplo n.º 9
0
                    default='/usr/local/srv/tfrecords/train/*2007*.tfrecords',
                    help="Dataset glob for train")
# Dataset globs for the validation and test splits.
flags.DEFINE_string('val_dataset',
                    '/usr/local/srv/tfrecords/val/*2007*.tfrecords',
                    "Dataset glob for validate")
flags.DEFINE_string('test_dataset',
                    '/usr/local/srv/tfrecords/test/*2007*.tfrecords',
                    "Dataset glob for test")

# Model metadata files.
flags.DEFINE_string('anchors_path', 'model_data/yolo_anchors.txt',
                    "Anchors path")
flags.DEFINE_string('classes_path', 'model_data/voc_classes.txt',
                    "Classes Path")

# Several learning rates may be given; each must be >= 0.
flags.DEFINE_multi_float('learning_rate', [1e-3, 1e-4], "Learning rate",
                         lower_bound=0)
flags.DEFINE_enum_class('opt', None, OPT,
                        "Select optimization, One of {'XLA','DEBUG','MKL'}")
flags.DEFINE_string('tpu_address', None, "TPU address")
flags.DEFINE_bool('freeze', False, "Whether freeze backbone")
flags.DEFINE_bool('prune', False, "Whether prune model")


def parse_tuple(val):
    """Convert a flag value into a tuple.

    String input is parsed as comma-separated integers. Surrounding
    whitespace, enclosing parentheses/brackets and empty pieces (e.g. a
    trailing comma) are tolerated, so ``"(416, 416)"``, ``"[416,416]"`` and
    ``"416,416"`` all give ``(416, 416)``, and ``"(3,)"`` gives ``(3,)``.
    Any other value is passed to ``tuple()`` unchanged.

    Args:
        val: A string of comma-separated integers, or any iterable.

    Returns:
        A tuple of ints for string input, otherwise ``tuple(val)``.

    Raises:
        ValueError: if a non-empty comma-separated piece is not an integer.
    """
    if isinstance(val, str):
        # Strip the enclosing brackets explicitly. Blindly slicing off the
        # first and last characters mangled un-bracketed input such as "1,2".
        inner = val.strip().strip('()[]')
        return tuple(int(num) for num in inner.split(',') if num.strip())
    return tuple(val)
Exemplo n.º 10
0
    'Parent directory in which output TFRecords will be saved.')
# Input / output selection.
flags.DEFINE_string(
    'data_dir', default=None,
    help='<data_dir>/<task> contains output of create_splits.py.')
flags.DEFINE_enum(
    'loader_cls', default='Pfam34Loader', enum_values=['Pfam34Loader'],
    help='The loader class to fetch the unpaired sequences with.')
flags.DEFINE_string(
    'task', default='iid_ood_clans',
    help='Task for which to generate TFRecords.')
flags.DEFINE_string(
    'split', default='train',
    help='Data split for which to generate TFRecords.')
flags.DEFINE_integer(
    'max_len', default=512,
    help='Maximum sequence length, including any special tokens.')

# Stratified sampling of sequence pairs.
flags.DEFINE_multi_string(
    'index_keys', default=['fam_key', 'ci_100'],
    help='Indexing keys for stratified sampling of sequence pairs.')
flags.DEFINE_multi_float(
    'smoothing', default=[1.0, 1.0],
    help='Smoothing coefficients for stratified sampling of sequence pairs.')
flags.DEFINE_string(
    'branch_key', default='ci_100',
    help='Branching key for stratified sampling of sequence pairs.')

# Shard generation.
flags.DEFINE_integer(
    'seed', default=0, help='PRNG seed to generate the shard with.')
flags.DEFINE_integer(
    'n_pairs', default=102400, help='Number of sequence pairs per shard.')
FLAGS = flags.FLAGS

# Maps a --loader_cls value to the callable registered for it.
LOADERS = dict(Pfam34Loader=specs.make_pfam34_loader)


def make_serialize_fn():
    """Creates a serialization function for paired examples."""
Exemplo n.º 11
0
flags.DEFINE_string('dataset_path', None, 'TFRecord dataset path.')
flags.DEFINE_integer('learner_iterations_per_call', 500,
                     'Iterations per learner run call.')
flags.DEFINE_integer('policy_save_interval', 10000, 'Policy save interval.')
flags.DEFINE_integer('eval_interval', 10000, 'Evaluation interval.')
flags.DEFINE_integer('summary_interval', 1000, 'Summary interval.')
flags.DEFINE_integer('num_gradient_updates', 1000000,
                     'Total number of train iterations to perform.')
flags.DEFINE_float(
    'reward_shift', 0.0, 'Value to add to reward. Useful for sparse rewards, '
    'e.g. set to -0.5 for optimal performance on AntMaze environments which '
    'have rewards of 0 (most often) or 1 (when the target position is reached)'
)
# The help pieces are implicitly concatenated. The second piece previously
# lacked a trailing space, so the text read "...sincetanh_distribution...".
flags.DEFINE_multi_float(
    'action_clipping', None, 'Optional (min, max) values to clip actions. '
    'e.g. set to (-0.995, 0.995) when actions are close to -1 and 1 since '
    'tanh_distribution.log_prob(actions) will yield -inf and inf and make '
    'actor loss NaN. ')
# The help documents a (min, max) pair, so reject other lengths or an
# inverted range when flags are parsed. None (the default) is accepted
# because the flag is optional.
flags.register_validator(
    'action_clipping',
    lambda bounds: bounds is None or (
        len(bounds) == 2 and bounds[0] <= bounds[1]),
    message=('--action_clipping expects exactly two values (min, max) '
             'with min <= max'))
flags.DEFINE_bool(
    'use_trajectories', False,
    'Whether dataset samples are stored as trajectories. '
    'If False, stored as transitions')
flags.DEFINE_multi_string('gin_file', None, 'Paths to the gin-config files.')
flags.DEFINE_multi_string('gin_param', None, 'Gin binding parameters.')


@gin.configurable
def train_eval(
        root_dir,
        dataset_path,
        env_name,
Exemplo n.º 12
0
    'each move.')
# NOTE(review): several defaults below are -1, which lies outside the ranges
# their help text documents (e.g. the penalty should be in [0.0, 2.0]). They
# look like "not set" sentinels — confirm before adding bounds checks.
flags.DEFINE_float(
    'mlperf_value_init_penalty', -1.0,
    'New children value initialization penalty. '
    'Child value = parent\'s value - penalty * color, '
    'clamped to [-1, 1].  Penalty should be in [0.0, 2.0]. '
    '0 is init-to-parent, 2.0 is init-to-loss [default]. '
    'This behaves similarly to Leela\'s FPU '
    '"First Play Urgency".')
flags.DEFINE_float('mlperf_holdout_pct', -1.0,
                   'Fraction of games to hold out for validation.')
flags.DEFINE_float('mlperf_disable_resign_pct', -1.0,
                   'Fraction of games to disable resignation for.')
flags.DEFINE_multi_float(
    'mlperf_resign_threshold', [0.0, 0.0],
    'Each game\'s resign threshold is picked randomly '
    'from the range '
    '[min_resign_threshold, max_resign_threshold)')
# The help describes a [min, max) range, so exactly two values are needed.
flags.register_validator(
    'mlperf_resign_threshold',
    lambda thresholds: thresholds is not None and len(thresholds) == 2,
    message=('--mlperf_resign_threshold expects exactly two values: '
             'min then max'))
flags.DEFINE_integer(
    'mlperf_parallel_games', -1,
    'Number of games to play concurrently on each selfplay '
    'thread. Inferences from a thread\'s concurrent games are '
    'batched up and evaluated together. Increasing '
    'concurrent_games_per_thread can help improve GPU or '
    'TPU utilization, especially for small models.')
flags.DEFINE_integer('mlperf_virtual_losses', -1,
                     'Number of virtual losses when running tree search')
flags.DEFINE_float(
    'mlperf_gating_win_rate', -1.0,
    'Win pct against the target model to define a converged '
    'model.')
Exemplo n.º 13
0
flags.DEFINE_float('max_v', 2500, 'Maximum value on the y/z-axis.')
flags.DEFINE_float('min_v', 0, 'Minimum value on the y/z-axis.')

# The title for each parameter vector passed above.
flags.DEFINE_string('title1', 'p1', 'Title for p1.')
flags.DEFINE_string('title2', 'p2', 'Title for p2.')
flags.DEFINE_string('title3', 'p3', 'Title for p3.')

# Allows skipping of interpolating just to do plotting.
flags.DEFINE_bool(
    'visualize_only', False, 'Will only do visualization. '
    'If passed must specify --precomputed_interpolation')
flags.DEFINE_string('precomputed_interpolation', None,
                    'File path to a precomputed interpolation.')
# Enforce the requirement stated in the --visualize_only help text when the
# command line is parsed. The checker receives a {flag_name: value} dict.
flags.register_multi_flags_validator(
    ['visualize_only', 'precomputed_interpolation'],
    lambda flag_values: (not flag_values['visualize_only']
                         or bool(flag_values['precomputed_interpolation'])),
    message='--visualize_only requires --precomputed_interpolation')

# Policy and environment details.
flags.DEFINE_string('env', 'Hopper-v1', 'Name of Environment.')
flags.DEFINE_integer('env_seed', 0, 'Seed for the environment.')
flags.DEFINE_integer('global_seed', 1, 'Seed for all the rngs.')
flags.DEFINE_integer('save_every', 10, 'Save results after these many epochs.')
flags.DEFINE_integer('batch_size', 16, 'Number of environments to run.')
flags.DEFINE_integer('n_trajectories', 128, 'Number of trajectories to use.')
flags.DEFINE_integer(
    'max_steps_env', 1500, 'Maximum number of steps to run in the environment '
    'before termination.')
flags.DEFINE_float('std', None, 'Standard deviations for the policy.')
flags.DEFINE_multi_float('stds', None,
                         'A list of standard deviations for the policy.')
# Restricted to the two documented choices so a typo is caught when flags
# are parsed.
flags.DEFINE_enum('policy_type', 'normal', ['discrete', 'normal'],
                  'Type of policy. Either `discrete` or `normal`.')
Exemplo n.º 14
0
flags.DEFINE_string("baseline_dir", None,
                    "Path to directory containing the results of the baseline experiments.")

flags.DEFINE_string("filter", "",
                    "Run only on experiment names that contain this string")

flags.DEFINE_integer("n_bins", 100,
                     "Number of bins for plot binning")

# NOTE(review): same help text as --n_bins; presumably this one applies to
# the eff_suc plot only — confirm and make the help specific.
flags.DEFINE_integer("n_bins_eff_suc", None,
                     "Number of bins for plot binning")

flags.DEFINE_integer("dpi", 100,
                     "DPI for saving")

# Axis limits for the eff_suc plot, each given as two values (lower, upper).
flags.DEFINE_multi_float("x_lim", [-0.01, 0.21],
                         "X lim for eff_suc plot")

flags.DEFINE_multi_float("y_lim", [0.75, 1.01],
                         "Y lim for eff_suc plot")

# Reject limits that are not a (lower, upper) pair when flags are parsed.
flags.register_validator("x_lim", lambda lim: lim is None or len(lim) == 2,
                         message="--x_lim expects exactly two values")
flags.register_validator("y_lim", lambda lim: lim is None or len(lim) == 2,
                         message="--y_lim expects exactly two values")

flags.DEFINE_float("x_lim_eff", None,
                   "X lim for eps_eff plot")

flags.DEFINE_float("y_lim_eff", None,
                   "Y lim for eps_eff plot")

flags.DEFINE_float("zoom_factor", None,
                   "Zoom factor for zoom box")

flags.DEFINE_boolean("y_median", False,
                     "Aggregate y by median instead of mean")
Exemplo n.º 15
0
def avsr_flags():
    """Define the absl command-line flags used by the avsr models.

    Registers the generic, CNN, optimisation, logging, RNN seq2seq,
    Transformer and experimental flags. Presumably this is called once,
    before flags are parsed (absl rejects duplicate flag definitions).
    """

    # Generic flags
    flags.DEFINE_integer('buffer_size', 15000, 'Shuffle buffer size')
    flags.DEFINE_integer('batch_size', 64, 'Batch Size')
    flags.DEFINE_integer('embedding_size', 128, 'Embedding dimension')
    flags.DEFINE_integer('beam_width', 10, 'Beam Width')
    flags.DEFINE_boolean('enable_function', False, 'Enable Function?')
    flags.DEFINE_string('architecture', 'transformer', 'Network Architecture')
    flags.DEFINE_string('gpu_id', '0', 'GPU index')
    flags.DEFINE_string('input_modality', 'audio',
                        'Switch between A and V inputs')

    # The help text documents the valid range {0, 1, 2, 3}; enforce it when
    # flags are parsed.
    flags.DEFINE_integer('noise_level', 0, 'Noise level in range {0, 1, 2, 3}',
                         lower_bound=0, upper_bound=3)
    flags.DEFINE_boolean('mix_noise', False, 'TBA')

    flags.DEFINE_multi_integer('cnn_filters', (16, 32, 48, 64),
                               'Number of CNN filters per layer')
    flags.DEFINE_integer('cnn_dense_units', 256,
                         'Number of neurons in the CNN output layer')
    flags.DEFINE_string('cnn_activation', 'relu',
                        'Activation function in CNN layers')
    flags.DEFINE_string('cnn_final_activation', None,
                        'Activation function in the final CNN layer')
    flags.DEFINE_string('cnn_normalisation', 'layer_norm',
                        'Normalisation function in CNN blocks')
    flags.DEFINE_boolean('cnn_final_clip', True,
                         'Clip the activation of the final CNN layer')

    # flags.DEFINE_float('learning_rate', 0.001, 'Learning rate')
    flags.DEFINE_integer(
        'lr_warmup_steps', 750,
        'Number of steps for the Learning rate linear warmup')
    flags.DEFINE_float('max_gradient_norm', 1.0,
                       'Clip the global gradient norm')
    flags.DEFINE_string('optimiser', 'adam', 'Optimiser type')
    flags.DEFINE_boolean('amsgrad', False, 'Use AMSGrad ?')

    flags.DEFINE_string('logfile', 'default_logfile', 'Logfile name')
    flags.DEFINE_boolean('profiling', False,
                         'Enable profiling for TensorBoard')
    flags.DEFINE_boolean('write_halting_history', False,
                         'Dump segmentations to praat files')
    flags.DEFINE_boolean('plot_alignments', False,
                         'Write alignment summaries in TensorBoard')
    flags.DEFINE_boolean('use_tensorboard', False,
                         'Export TensorBoard summary')

    # RNN seq2seq FLAGS
    flags.DEFINE_string('encoder', 'RNN', 'Encoder type')
    flags.DEFINE_string('recurrent_activation', 'sigmoid',
                        'Activation function inside LSTM Cell')
    flags.DEFINE_string('encoder_input_normalisation', 'layer_norm',
                        'Normalisation function for the Encoder input')
    flags.DEFINE_string('cell_type', 'ln_lstm', 'Recurrent cell type')
    flags.DEFINE_multi_integer(
        'encoder_units_per_layer', 3 * (256, ),
        'Number of encoder cells in each recurrent layer')
    flags.DEFINE_multi_integer(
        'decoder_units_per_layer', 1 * (256, ),
        'Number of decoder cells in each recurrent layer')
    flags.DEFINE_multi_float('dropout_probability', 3 * (0.1, ),
                             'Dropout rate for RNN cells')
    flags.DEFINE_multi_float(
        'rnn_l1_l2_weight_regularisation', (0.0, 0.0001),
        'Weight regularisation (L1 and L2) for RNN cells')
    flags.DEFINE_boolean('recurrent_regularisation', False,
                         'Use regularisation in the recurrent LSTM kernel')
    flags.DEFINE_boolean('regularise_all', True,
                         'Regularise all model variables ?')
    flags.DEFINE_string('recurrent_initialiser', None,
                        'Recurrent kernel initialiser')
    flags.DEFINE_boolean('recurrent_dropout', True,
                         'Apply dropout on recurrent state')
    flags.DEFINE_boolean('enable_attention', True, 'Enable Attention ?')
    flags.DEFINE_float('output_sampling', 0.1, 'Output Sampling Rate')
    flags.DEFINE_float('lstm_state_dropout', 0.1,
                       'Dropout applied to the h state of the LSTM')
    flags.DEFINE_string('decoder_initialisation', 'final_encoder_state',
                        'Decoder initialisation scheme')
    flags.DEFINE_string('segmental_variant', 'v1', 'Segmental model variant')

    # Transformer Model
    flags.DEFINE_integer('transformer_hidden_size', 256,
                         'State size of the Transformer layers')
    flags.DEFINE_integer('transformer_num_encoder_layers', 6,
                         'Number of layers in the encoder stack')
    flags.DEFINE_integer('transformer_num_decoder_layers', 6,
                         'Number of layers in the decoder stack')
    flags.DEFINE_integer('transformer_num_heads', 1,
                         'Number of attention_heads')
    flags.DEFINE_integer('transformer_filter_size', 256, 'Filter size')
    # These two dropout flags previously reused the 'Filter size' help text.
    flags.DEFINE_float('transformer_relu_dropout', 0.1, 'ReLU dropout rate')
    flags.DEFINE_float('transformer_attention_dropout', 0.1,
                       'Attention dropout rate')
    flags.DEFINE_float('transformer_layer_postprocess_dropout', 0.1,
                       'Post-processing layer dropout')
    flags.DEFINE_string('transformer_dtype', 'float32', 'Data type')
    flags.DEFINE_integer('transformer_extra_decode_length', 0,
                         'Extra Decode Length')
    flags.DEFINE_integer('transformer_beam_size', 10, 'Beam search width')
    flags.DEFINE_float('transformer_alpha', 0.6,
                       'Used for length normalisation in beam search')
    flags.DEFINE_boolean('transformer_online_encoder', False,
                         'Whether or not to use a causal attention bias')
    flags.DEFINE_integer('transformer_encoder_lookahead', 11,
                         'Number of frames for encoder attention lookahead')
    flags.DEFINE_integer('transformer_encoder_lookback', 11,
                         'Number of frames for encoder attention lookback')
    flags.DEFINE_boolean('transformer_online_decoder', False,
                         'Whether or not to use online decoding')
    flags.DEFINE_integer(
        'transformer_decoder_lookahead', 5,
        'Number of segments for cross-modal attention lookahead')
    flags.DEFINE_integer(
        'transformer_decoder_lookback', 5,
        'Number of segments for cross-modal attention lookback')
    flags.DEFINE_float('transformer_l1_regularisation', 0.0,
                       'Transformer L1 weight regularisation')
    flags.DEFINE_float('transformer_l2_regularisation', 0.0,
                       'Transformer L2 weight regularisation')

    ## Experimental flags
    flags.DEFINE_integer('transformer_num_avalign_layers', 1,
                         'Number of layers in the AVAlign stack')
    flags.DEFINE_boolean('au_loss', False,
                         'Use the Action Unit Loss for multimodal encoders')
    flags.DEFINE_float('au_loss_weight', 10.0,
                       'Scalar multiplier of the AU loss')
    flags.DEFINE_boolean(
        'constrain_av_attention', False,
        'Limit the search space of the A-V alignment'
        ' to a window centred on the main diagonal')
    flags.DEFINE_integer('constrain_frames', 2, 'Num frames')
    flags.DEFINE_boolean('word_loss', False, 'Word counting loss')
    flags.DEFINE_float('wordloss_weight', 0.01, 'Num words loss')

    flags.DEFINE_string('wb_activation', 'sigmoid',
                        'Activation for the encoder halting unit')
Exemplo n.º 16
0
FLAGS = flags.FLAGS

# NOTE: the flag name keeps its historical spelling ('results_patern') so
# existing command lines keep working; only the help text is corrected.
flags.DEFINE_string('results_patern',
                    default=None,
                    help='File pattern for experiments results files')

flags.DEFINE_string('charts_path',
                    default=None,
                    help='Path where the charts will be saved on')

flags.DEFINE_list('chart_format_list',
                  default=['pdf'],
                  help='List of file formats to save charts on')

flags.DEFINE_multi_float('log_prec_at_pr',
                         default=None,
                         help='Log Precision at Recall levels')

flags.DEFINE_string(
    'results_metrics_csv_name',
    default=None,
    help='Name to the csv file where metrics will be saved on')

flags.DEFINE_bool('show_random_guess',
                  default=False,
                  help='Show line for random guess')

# Both flags default to None, so absl must refuse to run without them.
flags.mark_flag_as_required('results_patern')
flags.mark_flag_as_required('charts_path')

ClassifierResults = collections.namedtuple("ClassifierResults", [