예제 #1
0
def get_config():
    """Raw ImageNet config.

    Takes the base config from `_get_config()` and overrides only the model
    type. Every other field keeps its base value.
    """
    cfg = _get_config()
    model_cfg = cfg.model
    model_cfg.model_type = "resnet18_features"
    return cfg
예제 #2
0
def get_config():
  """Goal classifier config.

  Takes the base config from `_get_config()` and overrides the algorithm,
  training length, frame-sampler and model settings for "goal_classifier".

  Returns:
    The object returned by `_get_config()`, modified in place.
  """
  cfg = _get_config()
  cfg.algorithm = "goal_classifier"

  # Model overrides: classifier head, embedding normalization and
  # learnable temperature both switched off.
  model = cfg.model
  model.model_type = "resnet18_classifier"
  model.normalize_embeddings = False
  model.learnable_temp = False

  # Frame sampling overrides.
  sampler = cfg.frame_sampler
  sampler.strategy = "last_and_randoms"
  sampler.num_frames_per_sequence = 15

  cfg.optim.train_max_iters = 6_000
  return cfg
예제 #3
0
def get_config():
  """TCN config.

  Takes the base config from `_get_config()` and overrides the algorithm,
  training length, frame-sampler, model and TCN-loss settings for "tcn".

  Returns:
    The object returned by `_get_config()`, modified in place.
  """
  config = _get_config()

  config.algorithm = "tcn"
  config.optim.train_max_iters = 4_000
  config.frame_sampler.strategy = "window"
  config.frame_sampler.num_frames_per_sequence = 40
  config.model.model_type = "resnet18_linear"
  # Spelling matters: this key was previously written `normalize_embeddigs`,
  # which never overrode the real `normalize_embeddings` field (it either
  # adds an unused key or raises, depending on whether the config is locked).
  config.model.normalize_embeddings = False
  config.model.learnable_temp = False
  config.loss.tcn.pos_radius = 1
  config.loss.tcn.neg_radius = 4
  config.loss.tcn.num_pairs = 2
  config.loss.tcn.margin = 1.0
  config.loss.tcn.temperature = 0.1

  return config
예제 #4
0
def get_config():
  """LIFS config.

  Takes the base config from `_get_config()` and overrides the algorithm,
  training length, frame-sampler strategy, model, LIFS-loss and evaluation
  settings for "lifs".

  Returns:
    The object returned by `_get_config()`, modified in place.
  """
  cfg = _get_config()
  cfg.algorithm = "lifs"
  cfg.optim.train_max_iters = 8_000
  cfg.frame_sampler.strategy = "variable_strided"

  # Model overrides: autoencoder-style linear head with a 32-d embedding.
  model = cfg.model
  model.model_type = "resnet18_linear_ae"
  model.embedding_size = 32
  model.normalize_embeddings = False
  model.learnable_temp = False

  cfg.loss.lifs.temperature = 0.1

  # Evaluator names assigned to `eval.downstream_task_evaluators`, in order.
  evaluators = [
      "reward_visualizer",
      "kendalls_tau",
      "reconstruction_visualizer",
  ]
  cfg.eval.downstream_task_evaluators = evaluators

  return cfg
예제 #5
0
def get_config():
    """TCC config.

    Takes the base config from `_get_config()` and overrides the algorithm,
    training length, frame-sampler, model and TCC-loss settings for "tcc".

    Returns:
        The object returned by `_get_config()`, modified in place.
    """
    cfg = _get_config()
    cfg.algorithm = "tcc"
    cfg.optim.train_max_iters = 4_000

    # Uniform frame sampling with zero offset, 40 frames per sequence.
    sampler = cfg.frame_sampler
    sampler.strategy = "uniform"
    sampler.uniform_sampler.offset = 0
    sampler.num_frames_per_sequence = 40

    model = cfg.model
    model.model_type = "resnet18_linear"
    model.embedding_size = 32
    model.normalize_embeddings = False
    model.learnable_temp = False

    # TCC loss: deterministic (non-stochastic) matching, MSE regression loss
    # over L2 similarities.
    tcc_loss = cfg.loss.tcc
    tcc_loss.stochastic_matching = False
    tcc_loss.loss_type = "regression_mse"
    tcc_loss.similarity_type = "l2"
    tcc_loss.softmax_temperature = 1.0

    return cfg