Example #1
0
def create_models():
    """Build the ALI loss graphs used for training and evaluation.

    Returns:
        A ``(model, bn_model, bn_updates)`` triple where ``model`` wraps the
        plain loss graph, ``bn_model`` wraps the same losses built inside
        ``batch_normalization(ali)`` with dropout on the discriminator
        inputs, and ``bn_updates`` is a list of ``(population, new_value)``
        pairs that move each batch-norm population statistic towards the
        minibatch statistic with weight 0.05.
    """
    ali = create_model_brick()
    features = tensor.tensor4('features')
    noise = ali.theano_rng.normal(size=(features.shape[0], NLAT, 1, 1))

    def _create_model(with_dropout):
        graph = ComputationGraph(ali.compute_losses(features, noise))
        if not with_dropout:
            return Model(graph.outputs)
        disc = ali.discriminator
        # (bricks whose INPUT variables get dropout, dropout rate), applied in
        # this order: 0.2 on the first x/z layers, 0.5 on the deeper ones.
        dropout_plan = [
            ([disc.x_discriminator.layers[0],
              disc.z_discriminator.layers[0]], 0.2),
            (disc.x_discriminator.layers[2::3] +
             disc.z_discriminator.layers[2::2] +
             disc.joint_discriminator.layers[::2], 0.5),
        ]
        for bricks, rate in dropout_plan:
            targets = VariableFilter(bricks=bricks,
                                     roles=[INPUT])(graph.variables)
            graph = apply_dropout(graph, targets, rate)
        return Model(graph.outputs)

    model = _create_model(with_dropout=False)
    with batch_normalization(ali):
        bn_model = _create_model(with_dropout=True)

    # allow_duplicates=True may yield repeated pairs; the set removes them.
    unique_updates = set(
        get_batch_normalization_updates(bn_model, allow_duplicates=True))
    bn_updates = []
    for population, minibatch in unique_updates:
        bn_updates.append((population, minibatch * 0.05 + population * 0.95))

    return model, bn_model, bn_updates
Example #2
0
def create_models():
    """Build the ALI loss graphs with a single dropout stage.

    Returns:
        ``(model, bn_model, bn_updates)``: the plain loss model, the loss
        model built under ``batch_normalization(ali)`` with 0.2 dropout on
        the selected discriminator layer inputs, and the batch-norm
        population updates (new value = 0.05 * minibatch + 0.95 * population).
    """
    ali = create_model_brick()
    features = tensor.tensor4('features')
    noise = ali.theano_rng.normal(size=(features.shape[0], NLAT, 1, 1))

    def _create_model(with_dropout):
        graph = ComputationGraph(ali.compute_losses(features, noise))
        if with_dropout:
            disc = ali.discriminator
            dropped_bricks = [disc.x_discriminator.layers[0]]
            dropped_bricks += disc.x_discriminator.layers[2::3]
            dropped_bricks += disc.z_discriminator.layers[::2]
            dropped_bricks += disc.joint_discriminator.layers[::2]
            selector = VariableFilter(bricks=dropped_bricks, roles=[INPUT])
            graph = apply_dropout(graph, selector(graph.variables), 0.2)
        return Model(graph.outputs)

    model = _create_model(with_dropout=False)
    with batch_normalization(ali):
        bn_model = _create_model(with_dropout=True)

    # Remove duplicated (population, minibatch) pairs before building updates.
    unique_updates = set(
        get_batch_normalization_updates(bn_model, allow_duplicates=True))
    bn_updates = [(population, minibatch * 0.05 + population * 0.95)
                  for population, minibatch in unique_updates]

    return model, bn_model, bn_updates
Example #3
0
def create_models():
    """Build the GAN loss graphs for the mixture experiment.

    Returns:
        ``(model, bn_model, bn_updates)``: the plain loss model, the loss
        model built under ``batch_normalization(gan)``, and the batch-norm
        population updates (new value = 0.05 * minibatch + 0.95 * population).

    NOTE(review): both models are built with ``with_dropout=False``, so the
    dropout branch of ``_create_model`` is currently never exercised.
    """
    gan = create_model_brick()
    features = tensor.matrix('features')
    latent = gan.theano_rng.normal(size=(features.shape[0], NLAT))

    def _create_model(with_dropout):
        graph = ComputationGraph(gan.compute_losses(features, latent))
        if with_dropout:
            # 0.5 on the inputs of every discriminator child after the first,
            # then 0.2 on the discriminator's own inputs.
            stages = ((gan.discriminator.children[1:], 0.5),
                      ([gan.discriminator], 0.2))
            for bricks, rate in stages:
                selected = VariableFilter(bricks=bricks,
                                          roles=[INPUT])(graph.variables)
                graph = apply_dropout(graph, selected, rate)
        return Model(graph.outputs)

    model = _create_model(with_dropout=False)
    with batch_normalization(gan):
        bn_model = _create_model(with_dropout=False)

    unique_updates = set(
        get_batch_normalization_updates(bn_model, allow_duplicates=True))
    bn_updates = [(population, minibatch * 0.05 + population * 0.95)
                  for population, minibatch in unique_updates]

    return model, bn_model, bn_updates
Example #4
0
def create_models():
    """Create the plain and batch-normalized GAN loss models.

    Returns:
        A tuple ``(model, bn_model, bn_updates)``. ``bn_updates`` pairs each
        batch-norm population statistic with the expression
        ``minibatch * 0.05 + population * 0.95``.

    NOTE(review): ``with_dropout`` is False for both models here, so no
    dropout is actually applied.
    """
    gan = create_model_brick()
    x = tensor.matrix('features')
    z = gan.theano_rng.normal(size=(x.shape[0], NLAT))

    def _create_model(with_dropout):
        losses = gan.compute_losses(x, z)
        cg = ComputationGraph(losses)
        if not with_dropout:
            return Model(cg.outputs)
        inner_inputs = VariableFilter(
            bricks=gan.discriminator.children[1:],
            roles=[INPUT])(cg.variables)
        cg = apply_dropout(cg, inner_inputs, 0.5)
        outer_inputs = VariableFilter(
            bricks=[gan.discriminator],
            roles=[INPUT])(cg.variables)
        cg = apply_dropout(cg, outer_inputs, 0.2)
        return Model(cg.outputs)

    model = _create_model(with_dropout=False)
    with batch_normalization(gan):
        bn_model = _create_model(with_dropout=False)

    raw_updates = get_batch_normalization_updates(bn_model,
                                                  allow_duplicates=True)
    bn_updates = []
    for pair in set(raw_updates):
        population, minibatch = pair
        bn_updates.append((population, minibatch * 0.05 + population * 0.95))

    return model, bn_model, bn_updates
Example #5
0
    def create_models(self):
        """Build the packed-GAN loss models and batch-norm population updates.

        One latent sample ``z`` is drawn per packing slot
        (``self._config["num_packing"]``) from a circle Gaussian mixture.

        Returns:
            ``(model, bn_model, bn_updates)``: the plain loss model, the loss
            model built under ``batch_normalization(gan)``, and a list of
            ``(population, new_value)`` updates. Population variables sharing
            the same ``auto_name`` are merged: their minibatch statistics are
            averaged before the 0.05 / 0.95 moving-average update is formed.

        NOTE(review): both models are built with ``with_dropout=False``, so
        the dropout branch of ``_create_model`` is currently unused.
        """
        gan = self.create_model_brick()
        x = tensor.matrix('features')
        zs = []
        for _ in range(self._config["num_packing"]):
            z = circle_gaussian_mixture(num_modes=self._config["num_zmode"],
                                        num_samples=x.shape[0],
                                        dimension=self._config["num_zdim"],
                                        r=self._config["z_mode_r"],
                                        std=self._config["z_mode_std"])
            zs.append(z)

        def _create_model(with_dropout):
            cg = ComputationGraph(gan.compute_losses(x, zs))
            if with_dropout:
                inputs = VariableFilter(bricks=gan.discriminator.children[1:],
                                        roles=[INPUT])(cg.variables)
                cg = apply_dropout(cg, inputs, 0.5)
                inputs = VariableFilter(bricks=[gan.discriminator],
                                        roles=[INPUT])(cg.variables)
                cg = apply_dropout(cg, inputs, 0.2)
            return Model(cg.outputs)

        model = _create_model(with_dropout=False)
        with batch_normalization(gan):
            bn_model = _create_model(with_dropout=False)

        pop_updates = list(
            set(
                get_batch_normalization_updates(bn_model,
                                                allow_duplicates=True)))

        # Merge updates whose population variable has the same auto_name
        # (the packed inputs can yield one update per slot for the same
        # statistic). A dict gives O(1) lookup instead of rescanning a list
        # of names for every update; the first variable seen for a name is
        # the one that receives the merged update.
        merged = []          # entries: [population, summed_minibatch, count]
        index_by_name = {}
        for population, minibatch in pop_updates:
            name = population.auto_name
            idx = index_by_name.get(name)
            if idx is None:
                index_by_name[name] = len(merged)
                merged.append([population, minibatch, 1])
            else:
                merged[idx][1] += minibatch
                merged[idx][2] += 1
        pop_update_merges_finals = [(population, total / count)
                                    for population, total, count in merged]

        bn_updates = [(p, m * 0.05 + p * 0.95)
                      for p, m in pop_update_merges_finals]

        return model, bn_model, bn_updates
Example #6
0
def main(mode, save_to, num_epochs, load_params=None,
         feature_maps=None, mlp_hiddens=None,
         conv_sizes=None, pool_sizes=None, stride=None, repeat_times=None,
         batch_size=None, num_batches=None, algo=None,
         test_set=None, valid_examples=None,
         dropout=None, max_norm=None, weight_decay=None,
         batch_norm=None):
    """Train or evaluate the Dogs-vs-Cats LeNet classifier.

    Parameters
    ----------
    mode : str
        ``'train'`` starts ``server.py`` (next to ``sys.argv[0]``) as a data
        server and runs the Blocks main loop; ``'test'`` classifies
        ``test_set`` and writes ``id,label`` rows to ``save_to``.
    save_to : str
        Checkpoint path in train mode (a ``_best`` copy is also written);
        CSV output path in test mode.
    num_epochs, num_batches : int or None
        Stopping criteria passed to ``FinishAfter``.
    load_params : str, optional
        Parameter file to initialise the model from.
    feature_maps, conv_sizes, pool_sizes, repeat_times : list of int, optional
        Per-convolution-stage settings; all must have the same length.
    mlp_hiddens : list of int, optional
        Hidden sizes of the top MLP (a 2-way softmax layer is appended).
    stride, batch_size, valid_examples, test_set, algo : optional
        Defaults: 1, 500, 2500, ``'test'``, ``'rmsprop'``.
    dropout : float, optional
        Dropout rate applied to every ReLU output when truthy.
    max_norm : float, optional
        Column/filter norm clipping for linear and convolutional weights.
    weight_decay : float, optional
        L2 penalty coefficient added to the training cost.
    batch_norm : bool, optional
        Enable batch normalization and population-statistic updates.

    Raises
    ------
    ValueError
        If the per-stage argument lists differ in length, or ``algo`` or
        ``mode`` is not recognised.
    """
    if feature_maps is None:
        feature_maps = [20, 50, 50]
    if mlp_hiddens is None:
        mlp_hiddens = [500]
    if conv_sizes is None:
        conv_sizes = [5, 5, 5]
    if pool_sizes is None:
        pool_sizes = [2, 2, 2]
    if repeat_times is None:
        repeat_times = [1, 1, 1]
    if batch_size is None:
        batch_size = 500
    if valid_examples is None:
        valid_examples = 2500
    if stride is None:
        stride = 1
    if test_set is None:
        test_set = 'test'
    if algo is None:
        algo = 'rmsprop'
    if batch_norm is None:
        batch_norm = False

    image_size = (128, 128)
    output_size = 2

    if (len(feature_maps) != len(conv_sizes) or
            len(feature_maps) != len(pool_sizes) or
            len(feature_maps) != len(repeat_times)):
        raise ValueError("OMG, inconsistent arguments")

    # Use ReLUs everywhere and softmax for the final prediction
    conv_activations = [Rectifier() for _ in feature_maps]
    mlp_activations = [Rectifier() for _ in mlp_hiddens] + [Softmax()]
    # list(zip(...)): on Python 3 zip() is a one-shot iterator; materialise
    # it so the brick receives a reusable sequence.
    convnet = LeNet(conv_activations, 3, image_size,
                    stride=stride,
                    filter_sizes=list(zip(conv_sizes, conv_sizes)),
                    feature_maps=feature_maps,
                    pooling_sizes=list(zip(pool_sizes, pool_sizes)),
                    repeat_times=repeat_times,
                    top_mlp_activations=mlp_activations,
                    top_mlp_dims=mlp_hiddens + [output_size],
                    border_mode='full',
                    batch_norm=batch_norm,
                    weights_init=Glorot(),
                    biases_init=Constant(0))
    # We push initialization config to set different initialization schemes
    # for convolutional layers.
    convnet.initialize()
    logging.info("Input dim: {} {} {}".format(
        *convnet.children[0].get_dim('input_')))
    for i, layer in enumerate(convnet.layers):
        if isinstance(layer, Activation):
            logging.info("Layer {} ({})".format(
                i, layer.__class__.__name__))
        else:
            logging.info("Layer {} ({}) dim: {} {} {}".format(
                i, layer.__class__.__name__, *layer.get_dim('output')))

    single_x = tensor.tensor3('image_features')
    x = tensor.tensor4('image_features')
    single_y = tensor.lvector('targets')
    y = tensor.lmatrix('targets')

    # Training
    with batch_normalization(convnet):
        probs = convnet.apply(x)
    cost = (CategoricalCrossEntropy().apply(y.flatten(), probs)
            .copy(name='cost'))
    error_rate = (MisclassificationRate().apply(y.flatten(), probs)
                  .copy(name='error_rate'))

    cg = ComputationGraph([cost, error_rate])
    extra_updates = []
    population_statistics = []

    if batch_norm:
        logger.debug("Apply batch norm")
        pop_updates = get_batch_normalization_updates(cg)
        # p stands for population mean
        # m stands for minibatch
        alpha = 0.005
        extra_updates = [(p, m * alpha + p * (1 - alpha))
                         for p, m in pop_updates]
        population_statistics = [p for p, m in extra_updates]
    if dropout:
        relu_outputs = VariableFilter(bricks=[Rectifier], roles=[OUTPUT])(cg)
        cg = apply_dropout(cg, relu_outputs, dropout)
    cost, error_rate = cg.outputs
    if weight_decay:
        logger.debug("Apply weight decay {}".format(weight_decay))
        cost += weight_decay * l2_norm(cg.parameters)
        cost.name = 'cost'

    # Validation
    valid_probs = convnet.apply_5windows(single_x)
    valid_cost = (CategoricalCrossEntropy().apply(single_y, valid_probs)
                  .copy(name='cost'))
    valid_error_rate = (MisclassificationRate().apply(
        single_y, valid_probs).copy(name='error_rate'))

    model = Model([cost, error_rate])
    if load_params:
        logger.info("Loaded params from {}".format(load_params))
        # Parameter files are binary archives; text mode breaks on Python 3.
        with open(load_params, 'rb') as src:
            model.set_parameter_values(load_parameters(src))

    # Training stream with random cropping
    train = DogsVsCats(("train",),
                       subset=slice(None, 25000 - valid_examples, None))
    train_str = DataStream(
        train, iteration_scheme=ShuffledScheme(train.num_examples, batch_size))
    train_str = add_transformers(train_str, random_crop=True)

    # Validation stream without cropping
    valid = DogsVsCats(("train",),
                       subset=slice(25000 - valid_examples, None, None))
    valid_str = DataStream(
        valid, iteration_scheme=SequentialExampleScheme(valid.num_examples))
    valid_str = add_transformers(valid_str)

    if mode == 'train':
        directory, _ = os.path.split(sys.argv[0])
        env = dict(os.environ)
        env['THEANO_FLAGS'] = 'floatX=float32'
        port = numpy.random.randint(1025, 10000)
        server = subprocess.Popen(
            [directory + '/server.py',
             str(25000 - valid_examples), str(batch_size), str(port)],
            env=env, stderr=subprocess.STDOUT)
        # Guard everything after the server starts so that any failure while
        # building the algorithm/extensions also terminates the subprocess.
        try:
            train_str = ServerDataStream(
                ('image_features', 'targets'), produces_examples=False,
                port=port)

            save_to_base, save_to_extension = os.path.splitext(save_to)

            if algo == 'rmsprop':
                step_rule = RMSProp(decay_rate=0.999, learning_rate=0.0003)
            elif algo == 'adam':
                step_rule = Adam()
            else:
                raise ValueError("Unknown algo: {}".format(algo))
            if max_norm:
                conv_params = VariableFilter(bricks=[Convolutional],
                                             roles=[WEIGHT])(cg)
                linear_params = VariableFilter(bricks=[Linear],
                                               roles=[WEIGHT])(cg)
                step_rule = CompositeRule(
                    [step_rule,
                     Restrict(VariableClipping(max_norm, axis=0),
                              linear_params),
                     Restrict(VariableClipping(max_norm, axis=(1, 2, 3)),
                              conv_params)])

            algorithm = GradientDescent(
                cost=cost, parameters=model.parameters,
                step_rule=step_rule)
            algorithm.add_updates(extra_updates)
            # `Timing` extension reports time for reading data, aggregating a
            # batch and monitoring.
            extensions = [
                Timing(every_n_batches=100),
                FinishAfter(after_n_epochs=num_epochs,
                            after_n_batches=num_batches),
                DataStreamMonitoring(
                    [valid_cost, valid_error_rate],
                    valid_str,
                    prefix="valid"),
                TrainingDataMonitoring(
                    [cost, error_rate,
                     aggregation.mean(algorithm.total_gradient_norm)],
                    prefix="train",
                    after_epoch=True),
                TrackTheBest("valid_error_rate"),
                Checkpoint(save_to, save_separately=['log'],
                           parameters=cg.parameters + population_statistics,
                           before_training=True, after_epoch=True)
                .add_condition(
                    ['after_epoch'],
                    OnLogRecord("valid_error_rate_best_so_far"),
                    (save_to_base + '_best' + save_to_extension,)),
                Printing(every_n_batches=100)]

            model = Model(cost)

            main_loop = MainLoop(
                algorithm,
                train_str,
                model=model,
                extensions=extensions)
            main_loop.run()
        finally:
            server.terminate()
    elif mode == 'test':
        classify = theano.function([single_x], valid_probs.argmax())
        test = DogsVsCats((test_set,))
        test_str = DataStream(
            test, iteration_scheme=SequentialExampleScheme(test.num_examples))
        test_str = add_transformers(test_str)
        correct = 0
        with open(save_to, 'w') as dst:
            print("id", "label", sep=',', file=dst)
            for index, example in enumerate(test_str.get_epoch_iterator()):
                image = example[0]
                # Run the network once per image and reuse the result.
                prediction = classify(image)
                print(index + 1, prediction, sep=',', file=dst)
                if len(example) > 1 and prediction == example[1]:
                    correct += 1
        print(correct / float(test.num_examples))
    else:
        raise ValueError("Unknown mode: {}".format(mode))
Example #7
0
def train_snli_model(new_training_job,
                     config,
                     save_path,
                     params,
                     fast_start,
                     fuel_server,
                     seed,
                     model='simple'):
    """Build the SNLI/MNLI model graphs and run the Blocks main loop.

    Parameters
    ----------
    new_training_job : bool
        When False, main-loop state is loaded from ``<save_path>/main_loop.tar``
        before training.
    config : dict
        Experiment configuration. Relative data paths in it are rewritten in
        place to absolute paths under ``fuel.config.data_path[0]``, and the
        result is written to ``<save_path>/config.json``.
    save_path : str
        Job directory (created if missing) for logs, config, checkpoints and
        summaries.
    params : str or None
        Optional parameter file used to initialise the model parameters and
        the batch-norm population statistics.
    fast_start : bool
        When True, disables the before-training validation, similarity
        evaluation and checkpointing hooks (and the after-training
        checkpoint / best-tracking).
    fuel_server : bool
        Serve the training stream from a separate Fuel server process.
    seed : int
        Seed passed to every data stream.
    model : {'simple', 'esim'}
        Architecture to build.

    Raises
    ------
    Exception
        If ``exclude_top_k`` leaves words with neither embedding, or if the
        visdom plotter exits immediately after being launched.
    NotImplementedError
        For an unknown ``model`` or ``config['layout']``.
    """
    if config['exclude_top_k'] > config['num_input_words'] and config[
            'num_input_words'] > 0:
        raise Exception("Some words have neither word nor def embedding")
    c = config

    # Create the job directory before configuring the logger: its log file
    # lives inside save_path, and a file handler may open it immediately.
    is_new_job_dir = not os.path.exists(save_path)
    if is_new_job_dir:
        os.mkdir(save_path)
    logger = configure_logger(name="snli_baseline_training",
                              log_file=os.path.join(save_path, "log.txt"))
    if is_new_job_dir:
        logger.info("Start a new job")
    else:
        logger.info("Continue an existing job")
    with open(os.path.join(save_path, "cmd.txt"), "w") as f:
        f.write(" ".join(sys.argv))

    # Make data paths nice: resolve relative paths against the Fuel data dir.
    for path in [
            'dict_path', 'embedding_def_path', 'embedding_path', 'vocab',
            'vocab_def', 'vocab_text'
    ]:
        if c.get(path, ''):
            if not os.path.isabs(c[path]):
                c[path] = os.path.join(fuel.config.data_path[0], c[path])

    main_loop_path = os.path.join(save_path, 'main_loop.tar')
    main_loop_best_val_path = os.path.join(save_path, 'main_loop_best_val.tar')
    stream_path = os.path.join(save_path, 'stream.pkl')

    # Save config to save_path (the with-block closes and flushes the file).
    with open(os.path.join(save_path, "config.json"), "w") as f:
        json.dump(config, f)

    if model == 'simple':
        nli_model, data, used_dict, used_retrieval, _ = _initialize_simple_model_and_data(
            c)
    elif model == 'esim':
        nli_model, data, used_dict, used_retrieval, _ = _initialize_esim_model_and_data(
            c)
    else:
        raise NotImplementedError()

    # Compute cost
    s1, s2 = T.lmatrix('sentence1'), T.lmatrix('sentence2')

    if c['dict_path']:
        assert os.path.exists(c['dict_path'])
        s1_def_map, s2_def_map = T.lmatrix('sentence1_def_map'), T.lmatrix(
            'sentence2_def_map')
        def_mask = T.fmatrix("def_mask")
        defs = T.lmatrix("defs")
    else:
        s1_def_map, s2_def_map = None, None
        def_mask = None
        defs = None

    s1_mask, s2_mask = T.fmatrix('sentence1_mask'), T.fmatrix('sentence2_mask')
    y = T.ivector('label')

    # cg[True]: graph built inside batch_normalization (used for training);
    # cg[False]: graph built outside it (used for validation/monitoring).
    cg = {}
    for train_phase in [True, False]:
        # NOTE: Please don't change outputs of cg
        if train_phase:
            with batch_normalization(nli_model):
                pred = nli_model.apply(s1,
                                       s1_mask,
                                       s2,
                                       s2_mask,
                                       def_mask=def_mask,
                                       defs=defs,
                                       s1_def_map=s1_def_map,
                                       s2_def_map=s2_def_map,
                                       train_phase=train_phase)
        else:
            pred = nli_model.apply(s1,
                                   s1_mask,
                                   s2,
                                   s2_mask,
                                   def_mask=def_mask,
                                   defs=defs,
                                   s1_def_map=s1_def_map,
                                   s2_def_map=s2_def_map,
                                   train_phase=train_phase)

        cost = CategoricalCrossEntropy().apply(y.flatten(), pred)
        error_rate = MisclassificationRate().apply(y.flatten(), pred)
        cg[train_phase] = ComputationGraph([cost, error_rate])

    # Weight decay (TODO: Make it less bug prone)
    if model == 'simple':
        weights_to_decay = VariableFilter(
            bricks=[dense for dense, relu, bn in nli_model._mlp],
            roles=[WEIGHT])(cg[True].variables)
        weight_decay = np.float32(c['l2']) * sum(
            (w**2).sum() for w in weights_to_decay)
    elif model == 'esim':
        weight_decay = 0.0
    else:
        raise NotImplementedError()

    final_cost = cg[True].outputs[0] + weight_decay
    final_cost.name = 'final_cost'

    # Add updates for population parameters (moving average, weight 0.1 on
    # the minibatch statistic).
    if c.get("bn", True):
        pop_updates = get_batch_normalization_updates(cg[True])
        extra_updates = [(p, m * 0.1 + p * (1 - 0.1)) for p, m in pop_updates]
    else:
        pop_updates = []
        extra_updates = []

    if params:
        logger.debug("Load parameters from {}".format(params))
        # Parameter archives are read in binary mode (text mode fails on
        # Python 3 for binary archive data).
        with open(params, 'rb') as src:
            loaded_params = load_parameters(src)
            cg[True].set_parameter_values(loaded_params)
            for param, m in pop_updates:
                param.set_value(loaded_params[get_brick(
                    param).get_hierarchical_name(param)])

    if os.path.exists(main_loop_path):
        logger.warning("Manually loading BN stats :(")
        with open(main_loop_path, 'rb') as src:
            loaded_params = load_parameters(src)

        for param, m in pop_updates:
            param.set_value(
                loaded_params[get_brick(param).get_hierarchical_name(param)])

    if theano.config.compute_test_value != 'off':
        test_value_data = next(
            data.get_stream('train', batch_size=4).get_epoch_iterator())
        s1.tag.test_value = test_value_data[0]
        s1_mask.tag.test_value = test_value_data[1]
        s2.tag.test_value = test_value_data[2]
        s2_mask.tag.test_value = test_value_data[3]
        y.tag.test_value = test_value_data[4]

    # Freeze embeddings
    if not c['train_emb']:
        frozen_params = [
            p for E in nli_model.get_embeddings_lookups() for p in E.parameters
        ]
        train_params = [p for p in cg[True].parameters]
        assert len(set(frozen_params) & set(train_params)) > 0
    else:
        frozen_params = []
    if not c.get('train_def_emb', 1):
        frozen_params_def = [
            p for E in nli_model.get_def_embeddings_lookups()
            for p in E.parameters
        ]
        train_params = [p for p in cg[True].parameters]
        assert len(set(frozen_params_def) & set(train_params)) > 0
        frozen_params += frozen_params_def
    train_params = [p for p in cg[True].parameters if p not in frozen_params]
    train_params_keys = [
        get_brick(p).get_hierarchical_name(p) for p in train_params
    ]

    # Optimizer
    algorithm = GradientDescent(cost=final_cost,
                                on_unused_sources='ignore',
                                parameters=train_params,
                                step_rule=Adam(learning_rate=c['lr']))
    algorithm.add_updates(extra_updates)
    m = Model(final_cost)

    parameters = m.get_parameter_dict()  # Blocks version mismatch
    logger.info("Trainable parameters" + "\n" +
                pprint.pformat([(key, parameters[key].get_value().shape)
                                for key in sorted(train_params_keys)],
                               width=120))
    logger.info("# of parameters {}".format(
        sum([
            np.prod(parameters[key].get_value().shape)
            for key in sorted(train_params_keys)
        ])))

    ### Monitored args ###
    train_monitored_vars = [final_cost] + cg[True].outputs
    monitored_vars = cg[False].outputs
    val_acc = monitored_vars[1]  # the error rate (hence choose_best=min)
    to_monitor_names = [
        'def_unk_ratio', 's1_merged_input_rootmean2', 's1_def_mean_rootmean2',
        's1_gate_rootmean2', 's1_compose_gate_rootmean2'
    ]
    for k in to_monitor_names:
        train_v, valid_v = VariableFilter(name=k)(
            cg[True]), VariableFilter(name=k)(cg[False])
        if len(train_v):
            logger.info("Adding {} tracking".format(k))
            train_monitored_vars.append(train_v[0])
            monitored_vars.append(valid_v[0])
        else:
            logger.warning("Didnt find {} in cg".format(k))

    if c['monitor_parameters']:
        for name in train_params_keys:
            param = parameters[name]
            # np.prod: numpy.product is deprecated and removed in NumPy 2.
            num_elements = np.prod(param.get_value().shape)
            norm = param.norm(2) / num_elements
            grad_norm = algorithm.gradients[param].norm(2) / num_elements
            step_norm = algorithm.steps[param].norm(2) / num_elements
            stats = T.stack([norm, grad_norm, step_norm,
                             step_norm / grad_norm])
            stats.name = name + '_stats'
            train_monitored_vars.append(stats)

    regular_training_stream = data.get_stream('train',
                                              batch_size=c['batch_size'],
                                              seed=seed)

    if fuel_server:
        # the port will be configured by the StartFuelServer extension
        training_stream = ServerDataStream(
            sources=regular_training_stream.sources,
            hwm=100,
            produces_examples=regular_training_stream.produces_examples)
    else:
        training_stream = regular_training_stream

    ### Build extensions ###

    extensions = [
        StartFuelServer(regular_training_stream,
                        stream_path,
                        hwm=100,
                        script_path=os.path.join(
                            os.path.dirname(__file__),
                            "../bin/start_fuel_server.py"),
                        before_training=fuel_server),
        Timing(every_n_batches=c['mon_freq']),
        ProgressBar(),
        RetrievalPrintStats(retrieval=used_retrieval,
                            every_n_batches=c['mon_freq_valid'],
                            before_training=not fast_start),
        Timestamp(),
        TrainingDataMonitoring(train_monitored_vars,
                               prefix="train",
                               every_n_batches=c['mon_freq']),
    ]

    if c['layout'] == 'snli':
        validation = DataStreamMonitoring(monitored_vars,
                                          data.get_stream('valid',
                                                          batch_size=14,
                                                          seed=seed),
                                          before_training=not fast_start,
                                          on_resumption=True,
                                          after_training=True,
                                          every_n_batches=c['mon_freq_valid'],
                                          prefix='valid')
        extensions.append(validation)
    elif c['layout'] == 'mnli':
        # before_training matches the snli and mismatched monitors, so the
        # record tracked by TrackTheBest exists before training as well.
        validation = DataStreamMonitoring(monitored_vars,
                                          data.get_stream('valid_matched',
                                                          batch_size=14,
                                                          seed=seed),
                                          every_n_batches=c['mon_freq_valid'],
                                          before_training=not fast_start,
                                          on_resumption=True,
                                          after_training=True,
                                          prefix='valid_matched')
        validation_mismatched = DataStreamMonitoring(
            monitored_vars,
            data.get_stream('valid_mismatched', batch_size=14, seed=seed),
            every_n_batches=c['mon_freq_valid'],
            before_training=not fast_start,
            on_resumption=True,
            after_training=True,
            prefix='valid_mismatched')
        extensions.extend([validation, validation_mismatched])
    else:
        raise NotImplementedError()

    # Similarity trackers for embeddings
    if c.get('vocab_def'):
        retrieval_vocab = Vocabulary(c['vocab_def'])
    else:
        retrieval_vocab = data.vocab

    retrieval_all = Retrieval(vocab_text=retrieval_vocab,
                              dictionary=used_dict,
                              max_def_length=c['max_def_length'],
                              exclude_top_k=0,
                              max_def_per_word=c['max_def_per_word'])

    for name in [
            's1_word_embeddings', 's1_dict_word_embeddings',
            's1_translated_word_embeddings'
    ]:
        variables = VariableFilter(name=name)(cg[False])
        if len(variables):
            s1_emb = variables[0]
            logger.info("Adding similarity tracking for " + name)
            # A bit sloppy about downcast

            if "dict" in name:
                embedder = construct_dict_embedder(
                    theano.function([s1, defs, def_mask, s1_def_map],
                                    s1_emb,
                                    allow_input_downcast=True),
                    vocab=data.vocab,
                    retrieval=retrieval_all)
            else:
                embedder = construct_embedder(
                    theano.function([s1], s1_emb, allow_input_downcast=True),
                    vocab=data.vocab)
            extensions.append(
                SimilarityWordEmbeddingEval(
                    embedder=embedder,
                    prefix=name,
                    every_n_batches=c['mon_freq_valid'],
                    before_training=not fast_start))

    track_the_best = TrackTheBest(validation.record_name(val_acc),
                                  before_training=not fast_start,
                                  every_n_epochs=c['save_freq_epochs'],
                                  after_training=not fast_start,
                                  every_n_batches=c['mon_freq_valid'],
                                  choose_best=min)
    extensions.append(track_the_best)

    # Special care for serializing embeddings
    if c.get('embedding_path') or c.get('embedding_def_path'):
        extensions.insert(
            0,
            LoadNoUnpickling(main_loop_path,
                             load_iteration_state=True,
                             load_log=True).set_conditions(
                                 before_training=not new_training_job))
        extensions.append(
            Checkpoint(main_loop_path,
                       parameters=train_params + [p for p, m in pop_updates],
                       save_main_loop=False,
                       save_separately=['log', 'iteration_state'],
                       before_training=not fast_start,
                       every_n_epochs=c['save_freq_epochs'],
                       after_training=not fast_start).add_condition(
                           ['after_batch', 'after_epoch'],
                           OnLogRecord(track_the_best.notification_name),
                           (main_loop_best_val_path, )))
    else:
        extensions.insert(
            0,
            Load(main_loop_path, load_iteration_state=True,
                 load_log=True).set_conditions(
                     before_training=not new_training_job))
        extensions.append(
            Checkpoint(main_loop_path,
                       parameters=cg[True].parameters +
                       [p for p, m in pop_updates],
                       before_training=not fast_start,
                       every_n_epochs=c['save_freq_epochs'],
                       after_training=not fast_start).add_condition(
                           ['after_batch', 'after_epoch'],
                           OnLogRecord(track_the_best.notification_name),
                           (main_loop_best_val_path, )))

    extensions.extend([
        DumpCSVSummaries(save_path,
                         every_n_batches=c['mon_freq_valid'],
                         after_training=True),
        DumpTensorflowSummaries(save_path,
                                after_epoch=True,
                                every_n_batches=c['mon_freq_valid'],
                                after_training=True),
        Printing(every_n_batches=c['mon_freq_valid']),
        PrintMessage(msg="save_path={}".format(save_path),
                     every_n_batches=c['mon_freq']),
        FinishAfter(after_n_batches=c['n_batches']).add_condition(
            ['after_batch'],
            OnLogStatusExceed('iterations_done', c['n_batches']))
    ])

    logger.info(extensions)

    ### Run training ###

    if "VISDOM_SERVER" in os.environ:
        print("Running visdom server")
        ret = subprocess.Popen([
            os.path.join(os.path.dirname(__file__), "../visdom_plotter.py"),
            "--visdom-server={}".format(os.environ['VISDOM_SERVER']),
            "--folder={}".format(save_path)
        ])
        time.sleep(0.1)
        # Popen.returncode is only populated by poll()/wait(); without poll()
        # this check could never detect a plotter that died on startup.
        if ret.poll() is not None:
            raise Exception("visdom plotter exited with code {}".format(
                ret.returncode))
        atexit.register(lambda: os.kill(ret.pid, signal.SIGINT))

    # NOTE(review): `cost` is the loop variable left over from the last
    # iteration above (train_phase=False), not `final_cost`; kept as-is since
    # checkpoints depend on this model's parameter dict.
    model = Model(cost)
    # Register BN population statistics so Load/Checkpoint handle them too.
    for p, m in pop_updates:
        model._parameter_dict[get_brick(p).get_hierarchical_name(p)] = p

    main_loop = MainLoop(algorithm,
                         training_stream,
                         model=model,
                         extensions=extensions)

    assert os.path.exists(save_path)
    main_loop.run()