def get_config():
    """Build the raw ImageNet baseline config.

    Starts from the shared base config returned by ``_get_config()`` and only
    switches the model type to ``"resnet18_features"``; every other setting is
    inherited from the base config unchanged.

    Returns:
        The base config object with ``model.model_type`` overridden.
    """
    cfg = _get_config()
    model_cfg = cfg.model
    model_cfg.model_type = "resnet18_features"
    return cfg
def get_config():
    """Build the goal-classifier config.

    Overrides the shared base config (from ``_get_config()``) to select the
    ``"goal_classifier"`` algorithm, a ``"last_and_randoms"`` frame sampler
    drawing 15 frames per sequence, and the ``"resnet18_classifier"`` model
    with embedding normalization and a learnable temperature both disabled.

    Returns:
        The base config object with the overrides below applied.
    """
    cfg = _get_config()

    cfg.algorithm = "goal_classifier"
    cfg.optim.train_max_iters = 6_000

    # Frame sampling overrides.
    sampler = cfg.frame_sampler
    sampler.strategy = "last_and_randoms"
    sampler.num_frames_per_sequence = 15

    # Model overrides.
    model_cfg = cfg.model
    model_cfg.model_type = "resnet18_classifier"
    model_cfg.normalize_embeddings = False
    model_cfg.learnable_temp = False

    return cfg
def get_config():
    """Build the TCN config.

    Overrides the shared base config (from ``_get_config()``) to select the
    ``"tcn"`` algorithm, a ``"window"`` frame sampler drawing 40 frames per
    sequence, the ``"resnet18_linear"`` model with embedding normalization and
    a learnable temperature disabled, and the TCN loss hyperparameters.

    Returns:
        The base config object with the overrides below applied.
    """
    config = _get_config()

    config.algorithm = "tcn"
    config.optim.train_max_iters = 4_000

    config.frame_sampler.strategy = "window"
    config.frame_sampler.num_frames_per_sequence = 40

    config.model.model_type = "resnet18_linear"
    # Previously misspelled as `normalize_embeddigs`, so the real field was
    # never set. Depending on whether the config is locked, the typo either
    # created a stray key or raised. The sibling configs all set this same
    # `normalize_embeddings` field.
    config.model.normalize_embeddings = False
    config.model.learnable_temp = False

    # TCN loss hyperparameters.
    config.loss.tcn.pos_radius = 1
    config.loss.tcn.neg_radius = 4
    config.loss.tcn.num_pairs = 2
    config.loss.tcn.margin = 1.0
    config.loss.tcn.temperature = 0.1

    return config
def get_config():
    """Build the LIFS config.

    Overrides the shared base config (from ``_get_config()``) to select the
    ``"lifs"`` algorithm with a ``"variable_strided"`` frame sampler and the
    ``"resnet18_linear_ae"`` model (32-dim embeddings, no embedding
    normalization, no learnable temperature). It also sets the LIFS loss
    temperature and replaces the list of downstream evaluators.

    Returns:
        The base config object with the overrides below applied.
    """
    cfg = _get_config()

    cfg.algorithm = "lifs"
    cfg.optim.train_max_iters = 8_000
    cfg.frame_sampler.strategy = "variable_strided"

    # Model overrides.
    model_cfg = cfg.model
    model_cfg.model_type = "resnet18_linear_ae"
    model_cfg.embedding_size = 32
    model_cfg.normalize_embeddings = False
    model_cfg.learnable_temp = False

    cfg.loss.lifs.temperature = 0.1

    # The evaluator list replaces whatever the base config specified.
    evaluators = [
        "reward_visualizer",
        "kendalls_tau",
        "reconstruction_visualizer",
    ]
    cfg.eval.downstream_task_evaluators = evaluators

    return cfg
def get_config():
    """Build the TCC config.

    Overrides the shared base config (from ``_get_config()``) to select the
    ``"tcc"`` algorithm and a ``"uniform"`` frame sampler (offset 0, 40 frames
    per sequence). The model is ``"resnet18_linear"`` with 32-dim embeddings,
    no embedding normalization and no learnable temperature. The TCC loss is
    configured for non-stochastic matching, a ``"regression_mse"`` loss with
    ``"l2"`` similarity and softmax temperature 1.0.

    Returns:
        The base config object with the overrides below applied.
    """
    cfg = _get_config()

    cfg.algorithm = "tcc"
    cfg.optim.train_max_iters = 4_000

    # Frame sampling overrides.
    sampler = cfg.frame_sampler
    sampler.strategy = "uniform"
    sampler.uniform_sampler.offset = 0
    sampler.num_frames_per_sequence = 40

    # Model overrides.
    model_cfg = cfg.model
    model_cfg.model_type = "resnet18_linear"
    model_cfg.embedding_size = 32
    model_cfg.normalize_embeddings = False
    model_cfg.learnable_temp = False

    # TCC loss overrides.
    tcc_loss = cfg.loss.tcc
    tcc_loss.stochastic_matching = False
    tcc_loss.loss_type = "regression_mse"
    tcc_loss.similarity_type = "l2"
    tcc_loss.softmax_temperature = 1.0

    return cfg