Example #1
def main(params):
    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features + 1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(',')],
                         dtype=np.float32)
    probs /= probs.sum()

    train_loader = OneHotLoader(n_features=opts.n_features,
                                batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch,
                                probs=probs)

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    if opts.sender_cell == 'transformer':
        sender = Sender(n_features=opts.n_features,
                        n_hidden=opts.sender_embedding)
        sender = core.TransformerSenderReinforce(
            agent=sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            max_len=opts.max_len,
            num_layers=opts.sender_num_layers,
            num_heads=opts.sender_num_heads,
            hidden_size=opts.sender_hidden,
            force_eos=force_eos,
            generate_style=opts.sender_generate_style,
            causal=opts.causal_sender)
    else:
        sender = Sender(n_features=opts.n_features,
                        n_hidden=opts.sender_hidden)

        sender = core.RnnSenderReinforce(sender,
                                         opts.vocab_size,
                                         opts.sender_embedding,
                                         opts.sender_hidden,
                                         cell=opts.sender_cell,
                                         max_len=opts.max_len,
                                         num_layers=opts.sender_num_layers,
                                         force_eos=force_eos)
    if opts.receiver_cell == 'transformer':
        receiver = Receiver(n_features=opts.n_features,
                            n_hidden=opts.receiver_embedding)
        receiver = core.TransformerReceiverDeterministic(
            receiver,
            opts.vocab_size,
            opts.max_len,
            opts.receiver_embedding,
            opts.receiver_num_heads,
            opts.receiver_hidden,
            opts.receiver_num_layers,
            causal=opts.causal_receiver)
    else:
        if not opts.impatient:
            receiver = Receiver(n_features=opts.n_features,
                                n_hidden=opts.receiver_hidden)
            receiver = core.RnnReceiverDeterministic(
                receiver,
                opts.vocab_size,
                opts.receiver_embedding,
                opts.receiver_hidden,
                cell=opts.receiver_cell,
                num_layers=opts.receiver_num_layers)
        else:
            receiver = Receiver(n_features=opts.receiver_hidden,
                                n_hidden=opts.vocab_size)
            # If impatient 1
            receiver = RnnReceiverImpatient(
                receiver,
                opts.vocab_size,
                opts.receiver_embedding,
                opts.receiver_hidden,
                cell=opts.receiver_cell,
                num_layers=opts.receiver_num_layers,
                max_len=opts.max_len,
                n_features=opts.n_features)

    sender.load_state_dict(
        torch.load(opts.sender_weights, map_location=torch.device('cpu')))
    receiver.load_state_dict(
        torch.load(opts.receiver_weights, map_location=torch.device('cpu')))

    if not opts.impatient:
        game = core.SenderReceiverRnnReinforce(
            sender,
            receiver,
            loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff,
            length_cost=opts.length_cost,
            unigram_penalty=opts.unigram_pen)
    else:
        game = SenderImpatientReceiverRnnReinforce(
            sender,
            receiver,
            loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff,
            length_cost=opts.length_cost,
            unigram_penalty=opts.unigram_pen)

    optimizer = core.build_optimizer(game.parameters())

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=[EarlyStopperAccuracy(opts.early_stopping_thr)])

    # Begin the per-position accuracy test

    position_sieve = np.zeros((opts.n_features, opts.max_len))

    for position in range(opts.max_len):

        dataset = [[torch.eye(opts.n_features).to(device), None]]

        if opts.impatient:
            sender_inputs, messages, receiver_inputs, receiver_outputs, _ = \
                dump_test_position_impatient(trainer.game,
                                             dataset,
                                             position=position,
                                             voc_size=opts.vocab_size,
                                             gs=False,
                                             device=device,
                                             variable_length=True)
        else:
            sender_inputs, messages, receiver_inputs, receiver_outputs, _ = \
                dump_test_position(trainer.game,
                                   dataset,
                                   position=position,
                                   voc_size=opts.vocab_size,
                                   gs=False,
                                   device=device,
                                   variable_length=True)

        acc_pos = []

        for sender_input, message, receiver_output in zip(
                sender_inputs, messages, receiver_outputs):
            input_symbol = sender_input.argmax()
            output_symbol = receiver_output.argmax()
            acc = (input_symbol == output_symbol).float().item()
            acc_pos.append(acc)

        acc_pos = np.array(acc_pos)

        position_sieve[:, position] = acc_pos

    # Put -1 for position after message_length
    _, messages = dump(trainer.game, opts.n_features, device, False)

    # Convert messages to numpy array
    messages_np = []
    for x in messages:
        x = x.cpu().numpy()
        messages_np.append(x)

    for i in range(len(messages_np)):
        # Message i
        message_i = messages_np[i]
        id_0 = np.where(message_i == 0)[0]

        if id_0.shape[0] > 0:
            for j in range(id_0[0] + 1, opts.max_len):
                position_sieve[i, j] = -1

    np.save("analysis/position_sieve.npy", position_sieve)

    core.close()
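
The loss passed to the game constructors above is not shown in this snippet. A minimal sketch consistent with the accuracy computation in the per-position loop (argmax of the Receiver output against the argmax of the one-hot Sender input) could look like the following; the project's actual definition may differ:

import torch.nn.functional as F

def loss(sender_input, _message, _receiver_input, receiver_output, _labels):
    # hypothetical reconstruction: cross-entropy against the index encoded by
    # the one-hot sender input, with accuracy as an auxiliary statistic
    targets = sender_input.argmax(dim=1)
    acc = (receiver_output.argmax(dim=1) == targets).detach().float()
    return F.cross_entropy(receiver_output, targets, reduction='none'), {'acc': acc}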
Example #2
    neptune.init('anonymous/anonymous', backend=neptune.OfflineBackend())
    with neptune.create_experiment(params=vars(opts),
                                   upload_source_files=get_filepaths(),
                                   tags=['']) as experiment:

        # Pretraining game
        if opts.pretrain:
            pretraining_game = PretrainingmGameGS(pretrained_senders, receiver, padding=opts.padding)
            sender_params = [{'params': sender.parameters(), 'lr': opts.pretraining_sender_lr}
                             for sender in pretrained_senders]
            receiver_params = [{'params': receiver.parameters(), 'lr': opts.pretraining_receiver_lr}]
            optimizer = torch.optim.Adam(sender_params + receiver_params)
            trainer = core.Trainer(
                game=pretraining_game, optimizer=optimizer, train_data=train_loader,
                validation_data=test_loader,
                callbacks=[
                    CompositionalityMetricGS(full_dataset, pretrained_senders[0], opts, opts.vocab_size, prefix='1_'),
                    CompositionalityMetricGS(full_dataset, pretrained_senders[1], opts, opts.vocab_size, prefix='2_'),
                    NeptuneMonitor(prefix='pretrain'),
                    core.ConsoleLogger(print_train_loss=not opts.on_slurm),
                    EarlyStopperAccuracy(threshold=0.95, field_name='accuracy', delay=1, train=False),
                ])
            trainer.train(n_epochs=500_000)
            pretraining_game.train(False)

        # Compositional game
        compositional_game = CompositionalGameGS(sender_3, receiver)
        sender_params = [{'params': sender_3.parameters(), 'lr': opts.sender_lr}]
        receiver_params = [{'params': receiver.parameters(), 'lr': opts.receiver_lr}]
        optimizer = torch.optim.Adam(sender_params + receiver_params)
        trainer = core.Trainer(game=compositional_game, optimizer=optimizer, train_data=train_loader,
                               validation_data=test_loader,
                               callbacks=[
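                                   # hypothetical continuation: the original example is
                                   # cut off here; plausibly it reuses the monitoring
                                   # callbacks from the pretraining phase above, e.g.
                                   NeptuneMonitor(prefix='compositional'),
                                   core.ConsoleLogger(print_train_loss=not opts.on_slurm),
                               ])
        trainer.train(n_epochs=opts.n_epochs)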
Example #3
                                  shuffle=True,
                                  opt=opts,
                                  batches_per_epoch=opts.batches_per_epoch,
                                  seed=None)
    validation_loader = ImagenetLoader(
        dataset,
        opt=opts,
        batch_size=opts.batch_size,
        batches_per_epoch=opts.batches_per_epoch,
        seed=7)
    game = get_my_game(opts)
    optimizer = core.build_optimizer(game.parameters())
    if opts.mode == 'gs':
        callbacks = [
            core.TemperatureUpdater(agent=game.sender, decay=0.9, minimum=0.1)
        ]
    else:
        callbacks = []

    callbacks.append(core.ConsoleLogger(as_json=True, print_train_loss=True))
    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=validation_loader,
                           callbacks=callbacks)

    trainer.train(n_epochs=opts.n_epochs)

    core.close()
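
Like the other snippets on this page, main presumably receives the raw command-line argument list; a typical entry point (assumed here, not shown in the excerpt) is:

if __name__ == '__main__':
    import sys
    main(sys.argv[1:])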
Example #4
def main(params):
    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features + 1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(',')],
                         dtype=np.float32)
    probs /= probs.sum()

    print('the probs are: ', probs, flush=True)

    train_loader = OneHotLoader(n_features=opts.n_features,
                                batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch,
                                probs=probs)

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    #################################
    # define sender (speaker) agent #
    #################################
    sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_hidden)
    sender = RnnSenderReinforce(sender,
                                opts.vocab_size,
                                opts.sender_embedding,
                                opts.sender_hidden,
                                cell=opts.sender_cell,
                                max_len=opts.max_len,
                                num_layers=opts.sender_num_layers,
                                noise_loc=opts.sender_noise_loc,
                                noise_scale=opts.sender_noise_scale)

    ####################################
    # define receiver (listener) agent #
    ####################################
    receiver = Receiver(n_features=opts.n_features,
                        n_hidden=opts.receiver_hidden)
    receiver = RnnReceiverDeterministic(receiver,
                                        opts.vocab_size,
                                        opts.receiver_embedding,
                                        opts.receiver_hidden,
                                        cell=opts.receiver_cell,
                                        num_layers=opts.receiver_num_layers,
                                        noise_loc=opts.receiver_noise_loc,
                                        noise_scale=opts.receiver_noise_scale)

    ###################
    # define  channel #
    ###################
    channel = Channel(vocab_size=opts.vocab_size, p=opts.channel_repl_prob)

    game = SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_symb_entropy_coeff=opts.sender_symb_entropy_coeff,
        sender_stop_entropy_coeff=opts.sender_stop_entropy_coeff,
        receiver_entropy_coeff=opts.receiver_entropy_coeff,
        length_cost=opts.length_cost,
        machineguntalk_cost=opts.machineguntalk_cost,
        channel=channel,
        sender_entropy_common_ratio=opts.sender_entropy_common_ratio)

    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        EarlyStopperAccuracy(opts.early_stopping_thr),
        core.ConsoleLogger(as_json=True, print_train_loss=True)
    ]

    if opts.checkpoint_dir:
        '''
        info in checkpoint_name:
            - n_features as f
            - vocab_size as vocab
            - random_seed as rs
            - lr as lr
            - sender_hidden as shid
            - receiver_hidden as rhid
            - sender_symb_entropy_coeff as symbsentr
            - sender_stop_entropy_coeff as stopsentr
            - length_cost as reg
            - max_len as max_len
            - sender_noise_scale as sscl
            - receiver_noise_scale as rscl
            - channel_repl_prob as crp
            - sender_entropy_common_ratio as scr
        '''
        checkpoint_name = (
            f'{opts.name}' + ('_uniform' if opts.probs == 'uniform' else '') +
            f'_f{opts.n_features}' + f'_vocab{opts.vocab_size}' +
            f'_rs{opts.random_seed}' + f'_lr{opts.lr}' +
            f'_shid{opts.sender_hidden}' + f'_rhid{opts.receiver_hidden}' +
            f'_symbsentr{opts.sender_symb_entropy_coeff}' +
            f'_stopsentr{opts.sender_stop_entropy_coeff}' +
            f'_reg{opts.length_cost}' + f'_max_len{opts.max_len}' +
            f'_sscl{opts.sender_noise_scale}' +
            f'_rscl{opts.receiver_noise_scale}' +
            f'_crp{opts.channel_repl_prob}' +
            f'_scr{opts.sender_entropy_common_ratio}')
        callbacks.append(
            core.CheckpointSaver(checkpoint_path=opts.checkpoint_dir,
                                 checkpoint_freq=opts.checkpoint_freq,
                                 prefix=checkpoint_name))

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=callbacks)

    trainer.train(n_epochs=opts.n_epochs)

    print('-- suffix test --')
    suffix_test(trainer.game, opts.n_features, device)
    print('-- dump --')
    dump(trainer.game, opts.n_features, device)
    core.close()
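
Channel is not defined in this snippet. Judging from its arguments (a vocabulary size and a replacement probability), a minimal sketch would replace each transmitted non-eos symbol with a random symbol with probability p; the project's real implementation may differ:

import torch
import torch.nn as nn

class Channel(nn.Module):
    # hypothetical sketch of the noisy channel assumed above: with probability
    # p, replace each non-eos symbol (eos is symbol 0 in EGG) by a random one
    def __init__(self, vocab_size, p):
        super().__init__()
        self.vocab_size = vocab_size
        self.p = p

    def forward(self, message):
        replace = (torch.rand(message.shape, device=message.device) < self.p) \
                  & (message != 0)
        noise = torch.randint_like(message, low=1, high=self.vocab_size)
        return torch.where(replace, noise, message)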
Example #5
def main(params):
    opts = get_params(params)
    print(json.dumps(vars(opts)))

    kwargs = {'num_workers': 1, 'pin_memory': True} if opts.cuda else {}
    transform = transforms.ToTensor()

    train_dataset = datasets.MNIST('./data',
                                   train=True,
                                   download=True,
                                   transform=transform)
    test_dataset = datasets.MNIST('./data',
                                  train=False,
                                  download=False,
                                  transform=transform)
    n_classes = 10

    corrupt_labels_(dataset=train_dataset,
                    p_corrupt=opts.p_corrupt,
                    seed=opts.random_seed + 1)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=opts.batch_size,
                                               shuffle=True,
                                               **kwargs)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=opts.batch_size,
                                              shuffle=False,
                                              **kwargs)

    deeper_alice = opts.deeper_alice == 1 and opts.deeper == 1
    deeper_bob = opts.deeper_alice != 1 and opts.deeper == 1

    sender = Sender(vocab_size=opts.vocab_size,
                    deeper=deeper_alice,
                    linear_channel=opts.linear_channel == 1,
                    softmax_channel=opts.softmax_non_linearity == 1)
    receiver = Receiver(vocab_size=opts.vocab_size,
                        n_classes=n_classes,
                        deeper=deeper_bob)

    if opts.softmax_non_linearity != 1 and opts.linear_channel != 1:
        sender = AlwaysRelaxedWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)

    optimizer = core.build_optimizer(game.parameters())

    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr)

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           as_json=True,
                           early_stopping=early_stopper,
                           print_train_loss=True)

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
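
Note that this example uses an older Trainer signature that accepted as_json, early_stopping and print_train_loss directly. With the callbacks-based API seen in Examples #4 and #14, roughly the equivalent construction would be:

trainer = core.Trainer(game=game,
                       optimizer=optimizer,
                       train_data=train_loader,
                       validation_data=test_loader,
                       callbacks=[
                           core.ConsoleLogger(as_json=True, print_train_loss=True),
                           early_stopper,
                       ])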
Example #6
        receiver = core.RnnReceiverDeterministic(receiver, opts.vocab_size, opts.receiver_embedding,
                                                 opts.receiver_hidden, cell=opts.receiver_cell)

        game = core.SenderReceiverRnnReinforce(sender, receiver, loss, sender_entropy_coeff=opts.sender_entropy_coeff,
                                               receiver_entropy_coeff=opts.receiver_entropy_coeff)
        callback = None
    elif opts.mode.lower() == 'gs':
        sender = core.RnnSenderGS(sender, opts.vocab_size, opts.sender_embedding, opts.sender_hidden,
                                  cell=opts.sender_cell, max_len=opts.max_len, temperature=opts.temperature,
                                  force_eos=opts.force_eos)

        receiver = core.RnnReceiverGS(receiver, opts.vocab_size, opts.receiver_embedding,
                                      opts.receiver_hidden, cell=opts.receiver_cell)

        game = core.SenderReceiverRnnGS(sender, receiver, loss)
        callback = sender.update_temp(0.9, 0.1)
    else:
        raise NotImplementedError(f'Unknown training mode, {opts.mode}')

    optimizer = torch.optim.Adam([
        {'params': game.sender.parameters(), 'lr': opts.sender_lr},
        {'params': game.receiver.parameters(), 'lr': opts.receiver_lr}
    ])

    trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader,
                           validation_data=test_loader, epoch_callback=callback, print_train_loss=True)
    trainer.train(n_epochs=opts.n_epochs)

    core.close()
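
The sender.update_temp(0.9, 0.1) call above presumably builds an epoch callback that decays the Gumbel-Softmax temperature by a factor of 0.9 down to a floor of 0.1; in the callback-based API this corresponds to the updater used in Example #3:

callbacks = [core.TemperatureUpdater(agent=sender, decay=0.9, minimum=0.1)]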

Example #7
File: train.py  Project: DebasishMaji/EGG
                                    opts.receiver_embedding,
                                    opts.receiver_hidden,
                                    cell=opts.receiver_cell
                                    )

        game = core.SenderReceiverRnnGS(sender, receiver, loss)
        callback = sender.update_temp(0.9, 0.1)
    else:
        raise NotImplementedError(f'Unknown training mode, {opts.mode}')

    optimizer = torch.optim.Adam([
        {'params': game.sender.parameters(), 'lr': opts.sender_lr},
        {'params': game.receiver.parameters(), 'lr': opts.receiver_lr}
    ])

    trainer = core.Trainer(game=game, optimizer=optimizer,
                           train_data=train_data, validation_data=validation_data,
                           epoch_callback=callback, as_json=opts.output_json)
    trainer.train(n_epochs=opts.n_epochs)

    if opts.evaluate:
        is_gs = opts.mode == 'gs'
        sender_inputs, messages, receiver_inputs, receiver_outputs, labels = \
            core.dump_sender_receiver(game, test_data, is_gs,
                                      variable_length=True, device=device)

        receiver_outputs = move_to(receiver_outputs, device)
        labels = move_to(labels, device)

        receiver_outputs = torch.stack(receiver_outputs)
        labels = torch.stack(labels)

        tensor_accuracy = receiver_outputs.argmax(dim=1) == labels
        accuracy = torch.mean(tensor_accuracy.float()).item()
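
The snippet ends right after computing accuracy; presumably it is then reported, along the lines of this hypothetical continuation:

        print(f'test accuracy: {accuracy}')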
Example #8
                              opts.sender_hidden,
                              cell=opts.sender_cell,
                              max_len=opts.max_len,
                              temperature=opts.temperature)

    receiver = Receiver(opts.receiver_hidden)
    receiver = core.RnnReceiverGS(receiver,
                                  opts.vocab_size,
                                  opts.receiver_embedding,
                                  opts.receiver_hidden,
                                  cell=opts.receiver_cell)

    game = core.SenderReceiverRnnGS(sender, receiver, loss)

    optimizer = core.build_optimizer(game.parameters())

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader)
    trainer.train(n_epochs=opts.n_epochs)

    sender_inputs, messages, _, receiver_outputs, labels = \
        core.dump_sender_receiver(game, test_loader, gs=True, device=device, variable_length=True)

    for (seq, l), message, output, label in zip(sender_inputs, messages,
                                                receiver_outputs, labels):
        print(f'{seq[:l]} -> {message} -> {output.argmax()} (label = {label})')

    core.close()
Example #9
def main(params):
    import copy
    opts = get_params(params)
    device = opts.device

    full_data = enumerate_attribute_value(opts.n_attributes, opts.n_values)
    if opts.density_data > 0:
        sampled_data = select_subset_V2(full_data, opts.density_data,
                                        opts.n_attributes, opts.n_values)
        full_data = copy.deepcopy(sampled_data)

    train, generalization_holdout = split_holdout(full_data)
    train, uniform_holdout = split_train_test(train, 0.1)

    generalization_holdout, train, uniform_holdout, full_data = [
        one_hotify(x, opts.n_attributes, opts.n_values)
        for x in [generalization_holdout, train, uniform_holdout, full_data]
    ]

    train, validation = ScaledDataset(train, opts.data_scaler), ScaledDataset(
        train, 1)

    generalization_holdout, uniform_holdout, full_data = ScaledDataset(
        generalization_holdout), ScaledDataset(uniform_holdout), ScaledDataset(
            full_data)
    generalization_holdout_loader, uniform_holdout_loader, full_data_loader = [
        DataLoader(x, batch_size=opts.batch_size)
        for x in [generalization_holdout, uniform_holdout, full_data]
    ]

    train_loader = DataLoader(train, batch_size=opts.batch_size)
    validation_loader = DataLoader(validation, batch_size=len(validation))

    n_dim = opts.n_attributes * opts.n_values

    if opts.receiver_cell in ['lstm', 'rnn', 'gru']:
        receiver = MMReceiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim)
        receiver = core.MMRnnReceiverDeterministic(receiver,
                                                   opts.vocab_size + 1,
                                                   opts.receiver_emb,
                                                   opts.receiver_hidden,
                                                   cell=opts.receiver_cell)
    else:
        raise ValueError(f'Unknown receiver cell, {opts.receiver_cell}')

    if opts.sender_cell in ['lstm', 'rnn', 'gru']:
        sender = MMSender(n_inputs=n_dim, n_hidden=opts.sender_hidden)
        s1 = SplitWrapper(sender, 0)
        s2 = SplitWrapper(sender, 1)
        sender1 = core.RnnSenderReinforce(agent=s1,
                                          vocab_size=opts.vocab_size,
                                          embed_dim=opts.sender_emb,
                                          hidden_size=opts.sender_hidden,
                                          max_len=opts.max_len,
                                          force_eos=False,
                                          cell=opts.sender_cell)
        sender1 = PlusNWrapper(sender1, 1)
        sender2 = core.RnnSenderReinforce(agent=s2,
                                          vocab_size=opts.vocab_size,
                                          embed_dim=opts.sender_emb,
                                          hidden_size=opts.sender_hidden,
                                          max_len=opts.max_len,
                                          force_eos=False,
                                          cell=opts.sender_cell)
        # sender2 = PlusNWrapper(sender2, opts.vocab_size + 1)
        sender2 = PlusNWrapper(sender2, 1)
        sender = CombineMMRnnSenderReinforce(sender1, sender2)
    else:
        raise ValueError(f'Unknown sender cell, {opts.sender_cell}')

    loss = DiffLoss(opts.n_attributes, opts.n_values)

    baseline = {
        'no': core.baselines.NoBaseline,
        'mean': core.baselines.MeanBaseline,
        'builtin': core.baselines.BuiltInBaseline
    }[opts.baseline]

    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=0.0,
        length_cost=0.0,
        baseline_type=baseline)
    optimizer = torch.optim.Adam(game.parameters(), lr=opts.lr)

    metrics_evaluator = MMMetrics(validation.examples,
                                  opts.device,
                                  opts.n_attributes,
                                  opts.n_values,
                                  opts.vocab_size + 1,
                                  freq=opts.stats_freq)

    loaders = []
    loaders.append(("generalization hold out", generalization_holdout_loader,
                    DiffLoss(opts.n_attributes,
                             opts.n_values,
                             generalization=True)))
    loaders.append(("uniform holdout", uniform_holdout_loader,
                    DiffLoss(opts.n_attributes, opts.n_values)))

    holdout_evaluator = Evaluator(loaders, opts.device, freq=0)
    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr,
                                         validation=True)

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=validation_loader,
                           callbacks=[
                               core.ConsoleLogger(as_json=True,
                                                  print_train_loss=False),
                               early_stopper, metrics_evaluator,
                               holdout_evaluator
                           ])
    trainer.train(n_epochs=opts.n_epochs)

    validation_acc = early_stopper.validation_stats[-1][1]['acc']
    uniformtest_acc = holdout_evaluator.results['uniform holdout']['acc']

    # Train new agents
    if validation_acc > 0.99:

        def _set_seed(seed):
            import random
            import numpy as np

            random.seed(seed)
            torch.manual_seed(seed)
            np.random.seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        core.get_opts().preemptable = False
        core.get_opts().checkpoint_path = None

        # freeze Sender and probe how fast a simple Receiver will learn the thing
        def retrain_receiver(receiver_generator, sender):
            receiver = receiver_generator()
            game = core.SenderReceiverRnnReinforce(sender,
                                                   receiver,
                                                   loss,
                                                   sender_entropy_coeff=0.0,
                                                   receiver_entropy_coeff=0.0)
            optimizer = torch.optim.Adam(receiver.parameters(), lr=opts.lr)
            early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr,
                                                 validation=True)

            trainer = core.Trainer(game=game,
                                   optimizer=optimizer,
                                   train_data=train_loader,
                                   validation_data=validation_loader,
                                   callbacks=[
                                       early_stopper,
                                       Evaluator(loaders, opts.device, freq=0)
                                   ])
            trainer.train(n_epochs=opts.n_epochs // 2)

            accs = [x[1]['acc'] for x in early_stopper.validation_stats]
            return accs

        frozen_sender = Freezer(copy.deepcopy(sender))

        def gru_receiver_generator():
            return core.MMRnnReceiverDeterministic(
                MMReceiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim),
                opts.vocab_size + 1, opts.receiver_emb,
                hidden_size=opts.receiver_hidden, cell='gru')

        def small_gru_receiver_generator():
            return core.MMRnnReceiverDeterministic(
                MMReceiver(n_hidden=50, n_outputs=n_dim),
                opts.vocab_size + 1, opts.receiver_emb,
                hidden_size=50, cell='gru')

        def tiny_gru_receiver_generator():
            return core.MMRnnReceiverDeterministic(
                MMReceiver(n_hidden=25, n_outputs=n_dim),
                opts.vocab_size + 1, opts.receiver_emb,
                hidden_size=25, cell='gru')

        def nonlinear_receiver_generator():
            return MMNonLinearReceiver(
                n_outputs=n_dim, vocab_size=opts.vocab_size + 1,
                max_length=opts.max_len, n_hidden=opts.receiver_hidden)

        for name, receiver_generator in [
            ('gru', gru_receiver_generator),
            ('nonlinear', nonlinear_receiver_generator),
            ('tiny_gru', tiny_gru_receiver_generator),
            ('small_gru', small_gru_receiver_generator),
        ]:

            for seed in range(17, 17 + 3):
                _set_seed(seed)
                accs = retrain_receiver(receiver_generator, frozen_sender)
                accs += [1.0] * (opts.n_epochs // 2 - len(accs))
                auc = sum(accs)
                print(
                    json.dumps({
                        "mode": "reset",
                        "seed": seed,
                        "receiver_name": name,
                        "auc": auc
                    }))

    print('--- End ---')

    core.close()
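
PlusNWrapper is not defined in this snippet. Given that both receivers are built with opts.vocab_size + 1, it plausibly shifts every symbol emitted by the wrapped Reinforce sender by n, keeping symbol 0 reserved for eos; a sketch under that assumption:

import torch.nn as nn

class PlusNWrapper(nn.Module):
    # hypothetical sketch: offset the sampled symbols of a wrapped sender by n
    def __init__(self, wrapped, n):
        super().__init__()
        self.wrapped = wrapped
        self.n = n

    def forward(self, *args, **kwargs):
        sequence, log_prob, entropy = self.wrapped(*args, **kwargs)
        return sequence + self.n, log_prob, entropy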
Example #10
File: train.py  Project: DebasishMaji/EGG
def main(params):
    opts = get_params(params)
    print(json.dumps(vars(opts)))

    kwargs = {'num_workers': 1, 'pin_memory': True} if opts.cuda else {}
    transform = transforms.ToTensor()

    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        './data', train=True, download=True, transform=transform),
                                               batch_size=opts.batch_size,
                                               shuffle=True,
                                               **kwargs)

    test_loader = torch.utils.data.DataLoader(datasets.MNIST(
        './data', train=False, transform=transform),
                                              batch_size=16 * 1024,
                                              shuffle=False,
                                              **kwargs)

    n_classes = 10

    binarize = False

    test_loader = SplitImages(TakeFirstLoader(test_loader, n=1),
                              rows_receiver=opts.receiver_rows,
                              rows_sender=opts.sender_rows,
                              binarize=binarize,
                              receiver_bottom=True)
    train_loader = SplitImages(train_loader,
                               rows_sender=opts.sender_rows,
                               rows_receiver=opts.receiver_rows,
                               binarize=binarize,
                               receiver_bottom=True)

    sender = Sender(vocab_size=opts.vocab_size)
    receiver = Receiver(vocab_size=opts.vocab_size, n_classes=n_classes)
    sender = core.GumbelSoftmaxWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)

    optimizer = core.build_optimizer(game.parameters())

    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr)

    intervention = CallbackEvaluator(test_loader,
                                     device=opts.device,
                                     loss=game.loss,
                                     is_gs=True,
                                     var_length=False,
                                     input_intervention=True)

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           epoch_callback=intervention,
                           as_json=True,
                           early_stopping=early_stopper)

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
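
diff_loss_symbol (shared with Example #5) is not shown in either snippet. For a SymbolGameGS over n_classes labels, a minimal sketch consistent with the other losses on this page is cross-entropy against the class label, with accuracy as the auxiliary statistic; the original may differ:

import torch.nn.functional as F

def diff_loss_symbol(_sender_input, _message, _receiver_input, receiver_output, labels):
    # hypothetical sketch: differentiable cross-entropy plus an accuracy aux
    acc = (receiver_output.argmax(dim=1) == labels).detach().float()
    return F.cross_entropy(receiver_output, labels, reduction='none'), {'acc': acc}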
Example #11
def main(params):
    opts = get_common_opts(params=params)
    print(f"{opts}\n")
    assert (
        not opts.batch_size % 2
    ), f"Batch size must be multiple of 2. Found {opts.batch_size} instead"
    print(
        f"Running a distruted training is set to: {opts.distributed_context.is_distributed}. "
        f"World size is {opts.distributed_context.world_size}. "
        f"Using batch of size {opts.batch_size} on {opts.distributed_context.world_size} device(s)\n"
        f"Applying augmentations: {opts.use_augmentations} with image size: {opts.image_size}.\n"
    )
    if not opts.distributed_context.is_distributed and opts.pdb:
        breakpoint()
    if opts.use_distributed_negatives and not opts.distributed_context.is_distributed:
        sys.exit("Distributed negatives cannot be used in non-distributed context")

    train_loader = get_dataloader(
        dataset_dir=opts.dataset_dir,
        dataset_name=opts.dataset_name,
        image_size=opts.image_size,
        batch_size=opts.batch_size,
        num_workers=opts.num_workers,
        is_distributed=opts.distributed_context.is_distributed,
        seed=opts.random_seed,
        use_augmentations=opts.use_augmentations,
        return_original_image=opts.return_original_image,
    )

    game = build_game(opts)

    model_parameters = add_weight_decay(game, opts.weight_decay, skip_name="bn")

    optimizer = torch.optim.SGD(
        model_parameters,
        lr=opts.lr,
        momentum=0.9,
    )
    optimizer_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=opts.n_epochs
    )

    if (
        opts.distributed_context.is_distributed
        and opts.distributed_context.world_size > 2
        and opts.use_larc
    ):
        optimizer = LARC(optimizer, trust_coefficient=0.001, clip=False, eps=1e-8)

    callbacks = get_callbacks(
        shared_vision=opts.shared_vision,
        n_epochs=opts.n_epochs,
        checkpoint_dir=opts.checkpoint_dir,
        sender=game.game.sender,
        train_gs_temperature=opts.train_gs_temperature,
        minimum_gs_temperature=opts.minimum_gs_temperature,
        update_gs_temp_frequency=opts.update_gs_temp_frequency,
        gs_temperature_decay=opts.gs_temperature_decay,
        is_distributed=opts.distributed_context.is_distributed,
    )

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        optimizer_scheduler=optimizer_scheduler,
        train_data=train_loader,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)

    data_args = {
        "image_size": opts.image_size,
        "batch_size": opts.batch_size,
        "dataset_name": "imagenet",
        "num_workers": opts.num_workers,
        "use_augmentations": False,
        "is_distributed": opts.distributed_context.is_distributed,
        "seed": opts.random_seed,
    }
    i_test_loader = get_dataloader(
        dataset_dir="/datasets01/imagenet_full_size/061417/val", **data_args
    )
    o_test_loader = get_dataloader(
        dataset_dir="/private/home/mbaroni/agentini/representation_learning/generalizaton_set_construction/80_generalization_data_set/",
        **data_args,
    )

    _, i_test_interaction = trainer.eval(i_test_loader)
    dump = dict((k, v.mean().item()) for k, v in i_test_interaction.aux.items())
    dump.update(dict(mode="VALIDATION_I_TEST"))
    print(json.dumps(dump), flush=True)

    _, o_test_interaction = trainer.eval(o_test_loader)
    dump = dict((k, v.mean().item()) for k, v in o_test_interaction.aux.items())
    dump.update(dict(mode="VALIDATION_O_TEST"))
    print(json.dumps(dump), flush=True)

    if opts.checkpoint_dir:
        output_path = Path(opts.checkpoint_dir)
        output_path.mkdir(exist_ok=True, parents=True)
        torch.save(i_test_interaction, output_path / "i_test_interaction")
        torch.save(o_test_interaction, output_path / "o_test_interaction")

    print("| FINISHED JOB")
Example #12
                                   tags=['buffled_berkeley']) as experiment:
        print(os.environ)
        compositional_game = core.SenderReceiverRnnReinforce(
            sender,
            receiver,
            loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff,
            length_cost=opts.length_cost)
        # compositional_game = CompositionalGameReinforce(sender, receiver, loss)
        sender_params = [{'params': sender.parameters(), 'lr': opts.sender_lr}]
        receiver_params = [{
            'params': receiver.parameters(),
            'lr': opts.receiver_lr
        }]
        optimizer = torch.optim.Adam(sender_params + receiver_params)
        trainer = core.Trainer(
            game=compositional_game,
            optimizer=optimizer,
            train_data=train_loader,
            validation_data=test_loader,
            callbacks=[
                CompositionalityMetric(full_dataset,
                                       opts,
                                       test_targets=test_targets),
                NeptuneMonitor(train_freq=100, test_freq=1),
                core.ConsoleLogger(
                    print_train_loss=not os.environ.get('SLURM_JOB_NAME')),
            ])
        trainer.train(n_epochs=opts.n_epochs)
Example #13
File: play.py  Project: Slowika/EGG
def main(params):
    opts = get_params(params)
    if opts.validation_batch_size == 0:
        opts.validation_batch_size = opts.batch_size
    print(opts, flush=True)

    # the following if statement controls aspects specific to the two game tasks: loss, input data and architecture of the Receiver
    # (the Sender is identical in both cases, mapping a single input attribute-value vector to a variable-length message)
    if opts.game_type == "discri":
        # the game object we will encounter below takes as one of its mandatory arguments a loss: a loss in EGG is expected to take as arguments the sender input,
        # the message, the Receiver input, the Receiver output and the labels (although some of these elements might not actually be used by a particular loss);
        # together with the actual loss computation, the loss function can return a dictionary with other auxiliary statistics: in this case, accuracy
        def loss(
            _sender_input,
            _message,
            _receiver_input,
            receiver_output,
            labels,
            _aux_input,
        ):
            # in the discriminative case, accuracy is computed by comparing the index with highest score in Receiver output (a distribution of unnormalized
            # probabilities over target positions) and the corresponding label read from input, indicating the ground-truth position of the target
            acc = (receiver_output.argmax(dim=1) == labels).detach().float()
            # similarly, the loss computes cross-entropy between the Receiver-produced target-position probability distribution and the labels
            loss = F.cross_entropy(receiver_output, labels, reduction="none")
            return loss, {"acc": acc}

        # the input data are read into DataLoader objects, which are pytorch constructs implementing standard data processing functionalities, such as shuffling
        # and batching
        # within our games, we implement dataset classes, such as AttValDiscriDataset, to read the input text files and convert the information they contain
        # into the form required by DataLoader
        # look at the definition of the AttValDiscriDataset (the class to read discrimination game data) in data_readers.py for further details
        # note that, for the training dataset, we first instantiate the AttValDiscriDataset object and then feed it to DataLoader, whereas for the
        # validation data (confusingly called "test" data due to code heritage inertia) we directly declare the AttValDiscriDataset when instantiating
        # DataLoader: the reason for this difference is that we need the train_ds object to retrieve the number of features of the input vectors
        train_ds = AttValDiscriDataset(path=opts.train_data,
                                       n_values=opts.n_values)
        train_loader = DataLoader(train_ds,
                                  batch_size=opts.batch_size,
                                  shuffle=True,
                                  num_workers=1)
        test_loader = DataLoader(
            AttValDiscriDataset(path=opts.validation_data,
                                n_values=opts.n_values),
            batch_size=opts.validation_batch_size,
            shuffle=False,
            num_workers=1,
        )
        # note that the number of features retrieved here concerns inputs after they are converted to 1-hot vectors
        n_features = train_ds.get_n_features()
        # we define here the core of the Receiver for the discriminative game, see the architectures.py file for details:
        # note that this will be embedded in a wrapper below to define the full agent
        receiver = DiscriReceiver(n_features=n_features,
                                  n_hidden=opts.receiver_hidden)

    else:  # reco game

        def loss(sender_input, _message, _receiver_input, receiver_output,
                 labels, _aux_input):
            # in the case of the recognition game, for each attribute we compute a different cross-entropy score
            # based on comparing the probability distribution produced by the Receiver over the values of each attribute
            # with the corresponding ground truth, and then averaging across attributes
            # accuracy is instead computed by considering as a hit only cases where, for each attribute, the Receiver
            # assigned the largest probability to the correct value
            # most of this function consists of the usual pytorch madness needed to reshape tensors in order to perform these computations
            n_attributes = opts.n_attributes
            n_values = opts.n_values
            batch_size = sender_input.size(0)
            receiver_output = receiver_output.view(batch_size * n_attributes,
                                                   n_values)
            receiver_guesses = receiver_output.argmax(dim=1)
            correct_samples = ((receiver_guesses == labels.view(-1)).view(
                batch_size, n_attributes).detach())
            acc = (torch.sum(correct_samples, dim=-1) == n_attributes).float()
            labels = labels.view(batch_size * n_attributes)
            loss = F.cross_entropy(receiver_output, labels, reduction="none")
            loss = loss.view(batch_size, -1).mean(dim=1)
            return loss, {"acc": acc}

        # again, see data_readers.py in this directory for the AttValRecoDataset data reading class
        train_loader = DataLoader(
            AttValRecoDataset(
                path=opts.train_data,
                n_attributes=opts.n_attributes,
                n_values=opts.n_values,
            ),
            batch_size=opts.batch_size,
            shuffle=True,
            num_workers=1,
        )
        test_loader = DataLoader(
            AttValRecoDataset(
                path=opts.validation_data,
                n_attributes=opts.n_attributes,
                n_values=opts.n_values,
            ),
            batch_size=opts.validation_batch_size,
            shuffle=False,
            num_workers=1,
        )
        # the number of features for the Receiver (input) and the Sender (output) is given by n_attributes*n_values because
        # they are fed/produce 1-hot representations of the input vectors
        n_features = opts.n_attributes * opts.n_values
        # we define here the core of the receiver for the recognition game, see the architectures.py file for details
        # this will be embedded in a wrapper below to define the full architecture
        receiver = RecoReceiver(n_features=n_features,
                                n_hidden=opts.receiver_hidden)

    # we are now outside the block that defined game-type-specific aspects of the games: note that the core Sender architecture
    # (see architectures.py for details) is shared by the two games (it maps an input vector to a hidden layer that will be used to initialize
    # the message-producing RNN): this will also be embedded in a wrapper below to define the full architecture
    sender = Sender(n_hidden=opts.sender_hidden, n_features=n_features)

    # now, we instantiate the full sender and receiver architectures, and connect them and the loss into a game object
    # the implementation differs slightly depending on whether communication is optimized via Gumbel-Softmax ('gs') or Reinforce ('rf', default)
    if opts.mode.lower() == "gs":
        # in the following lines, we embed the Sender and Receiver architectures into standard EGG wrappers that are appropriate for Gumbel-Softmax optimization
        # the Sender wrapper takes the hidden layer produced by the core agent architecture we defined above when processing input, and uses it to initialize
        # the RNN that generates the message
        sender = core.RnnSenderGS(
            sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            hidden_size=opts.sender_hidden,
            cell=opts.sender_cell,
            max_len=opts.max_len,
            temperature=opts.temperature,
        )
        # the Receiver wrapper takes the symbol produced by the Sender at each step (more precisely, in Gumbel-Softmax mode, a function of the overall probability
        # of non-eos symbols up to the step is used), maps it to a hidden layer through an RNN, and feeds this hidden layer to the
        # core Receiver architecture we defined above (possibly with other Receiver input, as determined by the core architecture) to generate the output
        receiver = core.RnnReceiverGS(
            receiver,
            vocab_size=opts.vocab_size,
            embed_dim=opts.receiver_embedding,
            hidden_size=opts.receiver_hidden,
            cell=opts.receiver_cell,
        )
        game = core.SenderReceiverRnnGS(sender, receiver, loss)
        # callback functions can be passed to the trainer object (see below) to operate at certain steps of training and validation
        # for example, the TemperatureUpdater (defined in callbacks.py in the core directory) will update the Gumbel-Softmax temperature hyperparameter
        # after each epoch
        callbacks = [
            core.TemperatureUpdater(agent=sender, decay=0.9, minimum=0.1)
        ]
    else:  # NB: any other string than gs will lead to rf training!
        # here, the interesting thing to note is that we use the same core architectures we defined above, but now we embed them in wrappers that are suited to
        # Reinforce-based optimization
        sender = core.RnnSenderReinforce(
            sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            hidden_size=opts.sender_hidden,
            cell=opts.sender_cell,
            max_len=opts.max_len,
        )
        receiver = core.RnnReceiverDeterministic(
            receiver,
            vocab_size=opts.vocab_size,
            embed_dim=opts.receiver_embedding,
            hidden_size=opts.receiver_hidden,
            cell=opts.receiver_cell,
        )
        game = core.SenderReceiverRnnReinforce(
            sender,
            receiver,
            loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=0,
        )
        callbacks = []

    # we are almost ready to train: we define here an optimizer calling standard pytorch functionality
    optimizer = core.build_optimizer(game.parameters())
    # in the following statement, we finally instantiate the trainer object with all the components we defined (the game, the optimizer, the data
    # and the callbacks)
    if opts.print_validation_events:
        # we add a callback that will print loss and accuracy after each training and validation pass (see ConsoleLogger in callbacks.py in core directory)
        # if requested by the user, we will also print a detailed log of the validation pass after full training: look at PrintValidationEvents in
        # language_analysis.py (core directory)
        trainer = core.Trainer(
            game=game,
            optimizer=optimizer,
            train_data=train_loader,
            validation_data=test_loader,
            callbacks=callbacks + [
                core.ConsoleLogger(print_train_loss=True, as_json=True),
                core.PrintValidationEvents(n_epochs=opts.n_epochs),
            ],
        )
    else:
        trainer = core.Trainer(
            game=game,
            optimizer=optimizer,
            train_data=train_loader,
            validation_data=test_loader,
            callbacks=callbacks +
            [core.ConsoleLogger(print_train_loss=True, as_json=True)],
        )

    # and finally we train!
    trainer.train(n_epochs=opts.n_epochs)
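
All of the options accessed above come from get_params; a typical invocation of this script looks like the following (flag values are purely illustrative; the flag names mirror the opts fields used above):

# python play.py --game_type discri --mode gs \
#     --train_data train.txt --validation_data validation.txt \
#     --n_values 10 --batch_size 32 --vocab_size 100 --max_len 10 \
#     --sender_hidden 256 --receiver_hidden 256 --temperature 1.0 --n_epochs 50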
Example #14
def main(params):
    opts = get_params(params)
    print(opts)

    kwargs = {"num_workers": 1, "pin_memory": True} if opts.cuda else {}
    transform = transforms.ToTensor()

    train_dataset = datasets.MNIST("./data",
                                   train=True,
                                   download=True,
                                   transform=transform)
    test_dataset = datasets.MNIST("./data",
                                  train=False,
                                  download=False,
                                  transform=transform)

    n_classes = 10

    corrupt_labels_(dataset=train_dataset,
                    p_corrupt=opts.p_corrupt,
                    seed=opts.random_seed + 1)
    label_mapping = torch.LongTensor([x % n_classes for x in range(100)])

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=opts.batch_size,
                                               shuffle=True,
                                               **kwargs)
    train_loader = DoubleMnist(train_loader, label_mapping)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16 * 1024,
                                              shuffle=False,
                                              **kwargs)
    test_loader = DoubleMnist(test_loader, label_mapping)

    deeper_alice = opts.deeper_alice == 1 and opts.deeper == 1
    deeper_bob = opts.deeper_alice != 1 and opts.deeper == 1

    sender = Sender(
        vocab_size=opts.vocab_size,
        deeper=deeper_alice,
        linear_channel=opts.linear_channel == 1,
        softmax_channel=opts.softmax_non_linearity == 1,
    )
    receiver = Receiver(vocab_size=opts.vocab_size,
                        n_classes=n_classes,
                        deeper=deeper_bob)

    if (opts.softmax_non_linearity != 1 and opts.linear_channel != 1
            and opts.force_discrete != 1):
        sender = AlwaysRelaxedWrapper(sender, temperature=opts.temperature)
    elif opts.force_discrete == 1:
        sender = core.GumbelSoftmaxWrapper(sender,
                                           temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)

    optimizer = core.build_optimizer(game.parameters())

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=[
            core.ConsoleLogger(as_json=True, print_train_loss=True),
            EarlyStopperAccuracy(opts.early_stopping_thr),
        ],
    )

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
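
corrupt_labels_ (also used in Example #5) is not shown. Given its arguments, a plausible sketch corrupts a random fraction p_corrupt of the MNIST training labels in place using the given seed; the real helper may differ:

import numpy as np
import torch

def corrupt_labels_(dataset, p_corrupt, seed):
    # hypothetical sketch: reassign a random fraction of the labels in place
    rng = np.random.RandomState(seed)
    n = len(dataset.targets)
    idx = torch.from_numpy(rng.choice(n, size=int(p_corrupt * n), replace=False))
    dataset.targets[idx] = torch.from_numpy(rng.randint(0, 10, size=len(idx)))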
Example #15
def main(params):
    print(torch.cuda.is_available())
    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features + 1, dtype=np.float32)
    #elif opts.probs == "creneau":
    #    ones = np.ones(int(opts.n_features/2))
    #    tens = 10*np.ones(opts.n_features-int(opts.n_features/2))
    #    probs = np.concatenate((tens,ones),axis=0)
    #elif opts.probs == "toy":
    #    fives = 5*np.ones(int(opts.n_features/10))
    #    ones = np.ones(opts.n_features-int(opts.n_features/10))
    #    probs = np.concatenate((fives,ones),axis=0)
    #elif opts.probs == "escalier":
    #    ones = np.ones(int(opts.n_features/4))
    #    tens = 10*np.ones(int(opts.n_features/4))
    #    huns = 100*np.ones(int(opts.n_features/4))
    #    thous = 1000*np.ones(opts.n_features-3*int(opts.n_features/4))
    #    probs = np.concatenate((thous,huns,tens,ones),axis=0)
    else:
        probs = np.array([float(x) for x in opts.probs.split(',')],
                         dtype=np.float32)

    probs /= probs.sum()

    print('the probs are: ', probs, flush=True)

    train_loader = OneHotLoader(n_features=opts.n_features,
                                batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch,
                                probs=probs)

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    if opts.sender_cell == 'transformer':
        sender = Sender(n_features=opts.n_features,
                        n_hidden=opts.sender_embedding)
        sender = core.TransformerSenderReinforce(
            agent=sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            max_len=opts.max_len,
            num_layers=opts.sender_num_layers,
            num_heads=5,
            hidden_size=opts.sender_hidden,
            force_eos=force_eos,
            generate_style=opts.sender_generate_style,
            causal=opts.causal_sender)
    else:
        if opts.sender_cell == 'cnn':
            sender = Sender(n_features=opts.n_features,
                            n_hidden=opts.sender_hidden)
            sender = CnnSenderReinforce(sender,
                                        opts.vocab_size,
                                        opts.sender_embedding,
                                        opts.sender_hidden,
                                        cell=opts.sender_cell,
                                        max_len=opts.max_len,
                                        num_layers=opts.sender_num_layers,
                                        force_eos=force_eos)

        else:
            sender = Sender(n_features=opts.n_features,
                            n_hidden=opts.sender_hidden)

            sender = core.RnnSenderReinforce(sender,
                                             opts.vocab_size,
                                             opts.sender_embedding,
                                             opts.sender_hidden,
                                             cell=opts.sender_cell,
                                             max_len=opts.max_len,
                                             num_layers=opts.sender_num_layers,
                                             force_eos=force_eos)
    if opts.receiver_cell == 'transformer':
        receiver = Receiver(n_features=opts.n_features,
                            n_hidden=opts.receiver_embedding)
        receiver = core.TransformerReceiverDeterministic(
            receiver,
            opts.vocab_size,
            opts.max_len,
            opts.receiver_embedding,
            5,
            opts.receiver_hidden,
            opts.receiver_num_layers,
            causal=opts.causal_receiver)
    else:
        if opts.receiver_cell == 'cnn':
            receiver = Receiver(n_features=opts.n_features,
                                n_hidden=opts.receiver_hidden)
            receiver = CnnReceiverDeterministic(
                receiver,
                opts.vocab_size,
                opts.receiver_embedding,
                opts.receiver_hidden,
                cell=opts.receiver_cell,
                num_layers=opts.receiver_num_layers)

        else:
            if not opts.impatient:
                receiver = Receiver(n_features=opts.n_features,
                                    n_hidden=opts.receiver_hidden)
                receiver = core.RnnReceiverDeterministic(
                    receiver,
                    opts.vocab_size,
                    opts.receiver_embedding,
                    opts.receiver_hidden,
                    cell=opts.receiver_cell,
                    num_layers=opts.receiver_num_layers)
            else:
                receiver = Receiver(n_features=opts.receiver_hidden,
                                    n_hidden=opts.vocab_size)
                # If impatient 1
                receiver = RnnReceiverImpatient(
                    receiver,
                    opts.vocab_size,
                    opts.receiver_embedding,
                    opts.receiver_hidden,
                    cell=opts.receiver_cell,
                    num_layers=opts.receiver_num_layers,
                    max_len=opts.max_len,
                    n_features=opts.n_features)
                # If impatient 2
                # receiver = RnnReceiverImpatient2(
                #     receiver, opts.vocab_size, opts.receiver_embedding,
                #     opts.receiver_hidden, cell=opts.receiver_cell,
                #     num_layers=opts.receiver_num_layers,
                #     max_len=opts.max_len, n_features=opts.n_features)

    if not opts.impatient:
        game = core.SenderReceiverRnnReinforce(
            sender,
            receiver,
            loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff,
            length_cost=opts.length_cost,
            unigram_penalty=opts.unigram_pen,
            reg=opts.reg)
    else:
        game = SenderImpatientReceiverRnnReinforce(
            sender,
            receiver,
            loss_impatient,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff,
            length_cost=opts.length_cost,
            unigram_penalty=opts.unigram_pen,
            reg=opts.reg)

    optimizer = core.build_optimizer(game.parameters())

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=[EarlyStopperAccuracy(opts.early_stopping_thr)])

    for epoch in range(int(opts.n_epochs)):

        print("Epoch: " + str(epoch))

        # halve the learning rate every 100 epochs; mutating
        # trainer.optimizer.defaults would not touch the existing param
        # groups, so update them directly
        if epoch > 0 and epoch % 100 == 0:
            for param_group in trainer.optimizer.param_groups:
                param_group["lr"] /= 2

        trainer.train(n_epochs=1)
        if opts.checkpoint_dir:
            trainer.save_checkpoint(
                name=f'{opts.name}_vocab{opts.vocab_size}_rs{opts.random_seed}'
                     f'_lr{opts.lr}_shid{opts.sender_hidden}'
                     f'_rhid{opts.receiver_hidden}'
                     f'_sentr{opts.sender_entropy_coeff}'
                     f'_reg{opts.length_cost}_max_len{opts.max_len}')

        if not opts.impatient:
            acc_vec, messages = dump(trainer.game, opts.n_features, device,
                                     False, epoch)
        else:
            acc_vec, messages = dump_impatient(trainer.game, opts.n_features,
                                               device, False, epoch)

        # ADDITION TO SAVE MESSAGES
        all_messages = np.asarray([x.cpu().numpy() for x in messages])

        if epoch % 50 == 0:
            torch.save(
                sender.state_dict(),
                opts.dir_save + "/sender/sender_weights" + str(epoch) + ".pth")
            torch.save(
                receiver.state_dict(), opts.dir_save +
                "/receiver/receiver_weights" + str(epoch) + ".pth")
            #print(acc_vec)

        np.save(opts.dir_save + '/messages/messages_' + str((epoch)) + '.npy',
                all_messages)
        np.save(opts.dir_save + '/accuracy/accuracy_' + str((epoch)) + '.npy',
                acc_vec)

    core.close()
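Note: `loss`, `dump`, and their impatient variants used above are defined elsewhere in this script. As an illustration only, a minimal sketch of the usual EGG-style one-hot reconstruction loss that this kind of game expects (the original may differ):

import torch.nn.functional as F

def loss(sender_input, _message, _receiver_input, receiver_output, _labels):
    # accuracy: did the receiver recover the active input feature?
    acc = (receiver_output.argmax(dim=1) ==
           sender_input.argmax(dim=1)).detach().float()
    nll = F.cross_entropy(receiver_output, sender_input.argmax(dim=1),
                          reduction='none')
    return nll, {'acc': acc}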
Example #16
0
File: train.py Project: robertodessi/EGG
def main(params):
    opts = get_opts(params=params)
    print(opts)
    print(
        f"Distributed training is set to: {opts.distributed_context.is_distributed}. "
        f"World size is {opts.distributed_context.world_size}\n"
        f"Using imagenet with image size: {opts.image_size}. "
        f"Using batch of size {opts.batch_size} on {opts.distributed_context.world_size} device(s)"
    )
    if not opts.distributed_context.is_distributed and opts.pdb:
        breakpoint()

    train_loader = get_dataloader(
        dataset_dir=opts.dataset_dir,
        dataset_name=opts.dataset_name,
        image_size=opts.image_size,
        batch_size=opts.batch_size,
        num_workers=opts.num_workers,
        is_distributed=opts.distributed_context.is_distributed,
        seed=opts.random_seed
    )

    simclr_game = build_game(
        batch_size=opts.batch_size,
        loss_temperature=opts.ntxent_tau,
        vision_encoder_name=opts.model_name,
        output_size=opts.output_size,
        is_distributed=opts.distributed_context.is_distributed
    )

    model_parameters = add_weight_decay(
        simclr_game,
        opts.weight_decay,
        skip_name='bn'
    )

    optimizer_original = torch.optim.SGD(
        model_parameters,
        lr=opts.lr,
        momentum=0.9,
    )
    optimizer = LARC(optimizer_original, trust_coefficient=0.001, clip=False, eps=1e-8)
    optimizer_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_original, T_max=opts.n_epochs)

    callbacks = [
        core.ConsoleLogger(as_json=True, print_train_loss=True),
        BestStatsTracker(),
        VisionModelSaver()
    ]

    if opts.distributed_context.is_distributed:
        callbacks.append(DistributedSamplerEpochSetter())

    trainer = core.Trainer(
        game=simclr_game,
        optimizer=optimizer,
        optimizer_scheduler=optimizer_scheduler,
        train_data=train_loader,
        callbacks=callbacks
    )
    trainer.train(n_epochs=opts.n_epochs)
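The `add_weight_decay` helper above splits the model's parameters into two optimizer groups so that weight decay is not applied to batch-norm parameters. A minimal sketch of such a helper, assuming the common convention of matching parameter names against `skip_name` (the actual EGG implementation may differ):

def add_weight_decay(model, weight_decay, skip_name='bn'):
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # parameters whose name mentions `skip_name` (e.g. batch norm)
        # are exempted from weight decay
        (no_decay if skip_name in name else decay).append(param)
    return [
        {'params': no_decay, 'weight_decay': 0.0},
        {'params': decay, 'weight_decay': weight_decay},
    ]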
Example #17
0
File: train.py Project: KIKOU2016/EGG
                                  force_eos=opts.force_eos)

        receiver = core.RnnReceiverGS(receiver,
                                      opts.vocab_size,
                                      opts.receiver_embedding,
                                      opts.receiver_hidden,
                                      cell=opts.receiver_cell)

        game = core.SenderReceiverRnnGS(sender, receiver, loss)
        callback = sender.update_temp(0.9, 0.1)
    else:
        raise NotImplementedError(f'Unknown training mode, {opts.mode}')

    optimizer = torch.optim.Adam([{
        'params': game.sender.parameters(),
        'lr': opts.sender_lr
    }, {
        'params': game.receiver.parameters(),
        'lr': opts.receiver_lr
    }])

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           epoch_callback=callback,
                           callbacks=[core.ConsoleLogger(as_json=True)])
    trainer.train(n_epochs=opts.n_epochs)

    core.close()
Example #18
0
        train_loader, test_loader = get_dsprites_dataloader(path_to_data=root,
                                                            batch_size=opts.batch_size, subsample=opts.subsample, image=True)
        image_shape = (64, 64)
        #train_loader, test_loader = get_dsprites(batch_size=opts.batch_size, subsample=opts.subsample)

    sender = VisualSender(z_dim=opts.z_dim)
    receiver = VisualReceiver(z_dim=opts.z_dim)
    game = betaVAE_Game(sender, receiver, z_dim=opts.z_dim, beta=opts.beta)

    optimizer = core.build_optimizer(game.parameters())


    # initialize and launch the trainer
    trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader, validation_data=test_loader,
                           callbacks=[core.ConsoleLogger(as_json=True, print_train_loss=True),
                                      ImageDumpCallback(test_loader.dataset, image_shape=image_shape),
                                      TopographicSimilarityLatents('euclidean', 'euclidean'), PosDisent(print_test=False, print_train=True)])
    trainer.train(n_epochs=opts.n_epochs)

    core.close()


if __name__ == "__main__":
    import sys
    main(sys.argv[1:])
Example #19
0
def main(params):
    opts = get_params(params)
    print(opts, flush=True)

    # For compatibility, after https://github.com/facebookresearch/EGG/pull/130
    # the meaning of `length` changed a bit. Before it included the EOS symbol; now
    # it doesn't. To ensure that hyperparameters/CL arguments do not change,
    # we subtract it here.
    opts.max_len -= 1

    device = opts.device

    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features + 1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(',')],
                         dtype=np.float32)
    probs /= probs.sum()

    print('the probs are: ', probs, flush=True)

    train_loader = OneHotLoader(n_features=opts.n_features,
                                batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch,
                                probs=probs)

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    if opts.sender_cell == 'transformer':
        sender = Sender(n_features=opts.n_features,
                        n_hidden=opts.sender_embedding)
        sender = core.TransformerSenderReinforce(
            agent=sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            max_len=opts.max_len,
            num_layers=opts.sender_num_layers,
            num_heads=opts.sender_num_heads,
            hidden_size=opts.sender_hidden,
            generate_style=opts.sender_generate_style,
            causal=opts.causal_sender)
    else:
        sender = Sender(n_features=opts.n_features,
                        n_hidden=opts.sender_hidden)

        sender = core.RnnSenderReinforce(sender,
                                         opts.vocab_size,
                                         opts.sender_embedding,
                                         opts.sender_hidden,
                                         cell=opts.sender_cell,
                                         max_len=opts.max_len,
                                         num_layers=opts.sender_num_layers)
    if opts.receiver_cell == 'transformer':
        receiver = Receiver(n_features=opts.n_features,
                            n_hidden=opts.receiver_embedding)
        receiver = core.TransformerReceiverDeterministic(
            receiver,
            opts.vocab_size,
            opts.max_len,
            opts.receiver_embedding,
            opts.receiver_num_heads,
            opts.receiver_hidden,
            opts.receiver_num_layers,
            causal=opts.causal_receiver)
    else:
        receiver = Receiver(n_features=opts.n_features,
                            n_hidden=opts.receiver_hidden)
        receiver = core.RnnReceiverDeterministic(
            receiver,
            opts.vocab_size,
            opts.receiver_embedding,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
            num_layers=opts.receiver_num_layers)

    empty_logger = LoggingStrategy.minimal()
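    # minimal() keeps the training-time interaction logs empty (cheaper);
    # logging is switched to maximal() after training for the final dump below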
    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=opts.receiver_entropy_coeff,
        train_logging_strategy=empty_logger,
        length_cost=opts.length_cost)

    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        EarlyStopperAccuracy(opts.early_stopping_thr),
        core.ConsoleLogger(as_json=True, print_train_loss=True)
    ]

    if opts.checkpoint_dir:
        checkpoint_name = f'{opts.name}_vocab{opts.vocab_size}_rs{opts.random_seed}_lr{opts.lr}_shid{opts.sender_hidden}_rhid{opts.receiver_hidden}_sentr{opts.sender_entropy_coeff}_reg{opts.length_cost}_max_len{opts.max_len}'
        callbacks.append(
            core.CheckpointSaver(checkpoint_path=opts.checkpoint_dir,
                                 prefix=checkpoint_name))

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=callbacks)

    trainer.train(n_epochs=opts.n_epochs)

    game.logging_strategy = LoggingStrategy.maximal()  # now log everything
    dump(trainer.game, opts.n_features, device, False)
    core.close()
Example #20
0
        'lr': opts.executive_sender_lr
    }]
    receivers_params = [{
        'params': receiver_ensemble_1.parameters(),
        'lr': opts.receiver_lr
    }, {
        'params': receiver_ensemble_2.parameters(),
        'lr': opts.receiver_lr
    }]
    optimizer = torch.optim.Adam(sender_params + receivers_params +
                                 executive_sender_params,
                                 weight_decay=1e-6)

    neptune.init('tomekkorbak/compositionality')
    with neptune.create_experiment(params=vars(opts),
                                   upload_source_files=get_filepaths(),
                                   tags=[]) as experiment:
        trainer = core.Trainer(
            game=game,
            optimizer=optimizer,
            train_data=train_loader,
            validation_data=test_loader,
            callbacks=[
                CompositionalityMeasurer(experiment, full_dataset, opts,
                                         test.indices),
                NeptuneMonitor(experiment=experiment),
                core.ConsoleLogger(print_train_loss=True),
                # TemperatureUpdater(experiment, agent=senders[0], decay=0.999, minimum=0.5)
            ])
        trainer.train(n_epochs=opts.n_epochs)
Example #21
0
def main(params):
    opts = get_params(params)
    print(opts)

    device = opts.device

    train_loader = OneHotLoader(n_bits=opts.n_bits,
                                bits_s=opts.bits_s,
                                bits_r=opts.bits_r,
                                batch_size=opts.batch_size,
                                batches_per_epoch=opts.n_examples_per_epoch /
                                opts.batch_size)

    test_loader = UniformLoader(n_bits=opts.n_bits,
                                bits_s=opts.bits_s,
                                bits_r=opts.bits_r)
    test_loader.batch = [x.to(device) for x in test_loader.batch]
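    # UniformLoader exposes a single fixed batch; it is moved to the device once here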

    if not opts.variable_length:
        sender = Sender(n_bits=opts.n_bits,
                        n_hidden=opts.sender_hidden,
                        vocab_size=opts.vocab_size)
        if opts.mode == 'gs':
            sender = core.GumbelSoftmaxWrapper(agent=sender,
                                               temperature=opts.temperature)
            receiver = Receiver(n_bits=opts.n_bits,
                                n_hidden=opts.receiver_hidden)
            receiver = core.SymbolReceiverWrapper(
                receiver,
                vocab_size=opts.vocab_size,
                agent_input_size=opts.receiver_hidden)
            game = core.SymbolGameGS(sender, receiver, diff_loss)
        elif opts.mode == 'rf':
            sender = core.ReinforceWrapper(agent=sender)
            receiver = Receiver(n_bits=opts.n_bits,
                                n_hidden=opts.receiver_hidden)
            receiver = core.SymbolReceiverWrapper(
                receiver,
                vocab_size=opts.vocab_size,
                agent_input_size=opts.receiver_hidden)
            receiver = core.ReinforceDeterministicWrapper(agent=receiver)
            game = core.SymbolGameReinforce(
                sender,
                receiver,
                diff_loss,
                sender_entropy_coeff=opts.sender_entropy_coeff)
        elif opts.mode == 'non_diff':
            sender = core.ReinforceWrapper(agent=sender)
            receiver = ReinforcedReceiver(n_bits=opts.n_bits,
                                          n_hidden=opts.receiver_hidden)
            receiver = core.SymbolReceiverWrapper(
                receiver,
                vocab_size=opts.vocab_size,
                agent_input_size=opts.receiver_hidden)

            game = core.SymbolGameReinforce(
                sender,
                receiver,
                non_diff_loss,
                sender_entropy_coeff=opts.sender_entropy_coeff,
                receiver_entropy_coeff=opts.receiver_entropy_coeff)
    else:
        if opts.mode != 'rf':
            print('Only mode=rf is supported atm')
            opts.mode = 'rf'

        if opts.sender_cell == 'transformer':
            receiver = Receiver(n_bits=opts.n_bits,
                                n_hidden=opts.receiver_hidden)
            sender = Sender(
                n_bits=opts.n_bits,
                n_hidden=opts.sender_hidden,
                vocab_size=opts.sender_hidden)  # TODO: not really vocab
            sender = core.TransformerSenderReinforce(
                agent=sender,
                vocab_size=opts.vocab_size,
                embed_dim=opts.sender_emb,
                max_len=opts.max_len,
                num_layers=1,
                num_heads=1,
                hidden_size=opts.sender_hidden)
        else:
            receiver = Receiver(n_bits=opts.n_bits,
                                n_hidden=opts.receiver_hidden)
            sender = Sender(
                n_bits=opts.n_bits,
                n_hidden=opts.sender_hidden,
                vocab_size=opts.sender_hidden)  # TODO: not really vocab
            sender = core.RnnSenderReinforce(agent=sender,
                                             vocab_size=opts.vocab_size,
                                             embed_dim=opts.sender_emb,
                                             hidden_size=opts.sender_hidden,
                                             max_len=opts.max_len,
                                             cell=opts.sender_cell)

        if opts.receiver_cell == 'transformer':
            receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_emb)
            receiver = core.TransformerReceiverDeterministic(
                receiver,
                opts.vocab_size,
                opts.max_len,
                opts.receiver_emb,
                num_heads=1,
                hidden_size=opts.receiver_hidden,
                num_layers=1)
        else:
            receiver = Receiver(n_bits=opts.n_bits,
                                n_hidden=opts.receiver_hidden)
            receiver = core.RnnReceiverDeterministic(receiver,
                                                     opts.vocab_size,
                                                     opts.receiver_emb,
                                                     opts.receiver_hidden,
                                                     cell=opts.receiver_cell)

            # unused: unconditionally overwritten by the Reinforce game below
            # game = core.SenderReceiverRnnGS(sender, receiver, diff_loss)

        game = core.SenderReceiverRnnReinforce(
            sender,
            receiver,
            diff_loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff)

    optimizer = torch.optim.Adam([
        dict(params=sender.parameters(), lr=opts.sender_lr),
        dict(params=receiver.parameters(), lr=opts.receiver_lr)
    ])

    loss = game.loss

    intervention = CallbackEvaluator(test_loader,
                                     device=device,
                                     is_gs=opts.mode == 'gs',
                                     loss=loss,
                                     var_length=opts.variable_length,
                                     input_intervention=True)

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=[
                               core.ConsoleLogger(as_json=True),
                               EarlyStopperAccuracy(opts.early_stopping_thr),
                               intervention
                           ])

    trainer.train(n_epochs=opts.n_epochs)

    core.close()
Example #22
0
File: train.py Project: Slowika/EGG
def main(params):
    opts = get_common_opts(params=params)
    print(f"{opts}\n")
    assert (
        not opts.batch_size % 2
    ), f"Batch size must be multiple of 2. Found {opts.batch_size} instead"
    print(
        f"Distributed training is set to: {opts.distributed_context.is_distributed}. "
        f"World size is {opts.distributed_context.world_size}. "
        f"Using batch of size {opts.batch_size} on {opts.distributed_context.world_size} device(s)\n"
        f"Applying augmentations: {opts.use_augmentations} with image size: {opts.image_size}.\n"
    )
    if not opts.distributed_context.is_distributed and opts.pdb:
        breakpoint()

    train_loader = get_dataloader(
        dataset_dir=opts.dataset_dir,
        dataset_name=opts.dataset_name,
        image_size=opts.image_size,
        batch_size=opts.batch_size,
        num_workers=opts.num_workers,
        is_distributed=opts.distributed_context.is_distributed,
        seed=opts.random_seed,
        use_augmentations=opts.use_augmentations,
        return_original_image=opts.return_original_image,
    )

    game = build_game(opts)

    model_parameters = add_weight_decay(game, opts.weight_decay, skip_name="bn")

    optimizer = torch.optim.SGD(
        model_parameters,
        lr=opts.lr,
        momentum=0.9,
    )
    optimizer_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=opts.n_epochs
    )

    if (
        opts.distributed_context.is_distributed
        and opts.distributed_context.world_size > 2
        and opts.use_larc
    ):
        optimizer = LARC(optimizer, trust_coefficient=0.001, clip=False, eps=1e-8)
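        # the CosineAnnealingLR above holds a reference to the inner SGD
        # optimizer, so it keeps annealing the base lr even after LARC wraps it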

    callbacks = get_callbacks(
        shared_vision=opts.shared_vision,
        n_epochs=opts.n_epochs,
        checkpoint_dir=opts.checkpoint_dir,
        sender=game.game.sender,
        train_gs_temperature=opts.train_gs_temperature,
        minimum_gs_temperature=opts.minimum_gs_temperature,
        update_gs_temp_frequency=opts.update_gs_temp_frequency,
        gs_temperature_decay=opts.gs_temperature_decay,
        is_distributed=opts.distributed_context.is_distributed,
    )

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        optimizer_scheduler=optimizer_scheduler,
        train_data=train_loader,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)

    print("| FINISHED JOB")
Example #23
0
def main(params):
    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features+1, dtype=np.float32)
    elif opts.probs == 'perso':
        probs = opts.n_features+1 - np.arange(1, opts.n_features+1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(',')], dtype=np.float32)
    probs /= probs.sum()

    print('the probs are: ', probs, flush=True)

    train_loader = OneHotLoader(n_features=opts.n_features, batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch, probs=probs)

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    if opts.sender_cell == 'transformer':
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_embedding)
        sender = core.TransformerSenderReinforce(agent=sender, vocab_size=opts.vocab_size,
                                                 embed_dim=opts.sender_embedding, max_len=opts.max_len,
                                                 num_layers=opts.sender_num_layers, num_heads=opts.sender_num_heads,
                                                 hidden_size=opts.sender_hidden,
                                                 force_eos=opts.force_eos,
                                                 generate_style=opts.sender_generate_style,
                                                 causal=opts.causal_sender)
    else:
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_hidden)

        sender = core.RnnSenderReinforce(sender,
                                   opts.vocab_size, opts.sender_embedding, opts.sender_hidden,
                                   cell=opts.sender_cell, max_len=opts.max_len, num_layers=opts.sender_num_layers,
                                   force_eos=force_eos)
    if opts.receiver_cell == 'transformer':
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_embedding)
        receiver = core.TransformerReceiverDeterministic(receiver, opts.vocab_size, opts.max_len,
                                                         opts.receiver_embedding, opts.receiver_num_heads, opts.receiver_hidden,
                                                         opts.receiver_num_layers, causal=opts.causal_receiver)
    else:
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
        receiver = core.RnnReceiverDeterministic(receiver, opts.vocab_size, opts.receiver_embedding,
                                             opts.receiver_hidden, cell=opts.receiver_cell,
                                             num_layers=opts.receiver_num_layers)

    game = core.SenderReceiverRnnReinforce(sender, receiver, loss, sender_entropy_coeff=opts.sender_entropy_coeff,
                                           receiver_entropy_coeff=opts.receiver_entropy_coeff,
                                           length_cost=opts.length_cost)

    optimizer = core.build_optimizer(game.parameters())

    trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=[EarlyStopperAccuracy(opts.early_stopping_thr),
                                      core.ConsoleLogger(as_json=True, print_train_loss=True)])

    """ mode: grab the accuracy at each epoch (kept for reference)
    accs=[]
    all_messages,acc=dump(trainer.game, opts.n_features, device, False)
    np.save('messages_0.npy', all_messages)
    accs.append(acc)
    for i in range(int(opts.n_epochs)):
        print(i)
        trainer.train(n_epochs=1)
        all_messages,acc=dump(trainer.game, opts.n_features, device, False)
        np.save('messages_'+str((i+1))+'.npy', all_messages)
        accs.append(acc)
    np.save('accuracy.npy',accs)
    """

    trainer.train(n_epochs=opts.n_epochs)

    # if opts.checkpoint_dir:
    #     trainer.save_checkpoint(name=f'{opts.name}_vocab{opts.vocab_size}_rs{opts.random_seed}_lr{opts.lr}_shid{opts.sender_hidden}_rhid{opts.receiver_hidden}_sentr{opts.sender_entropy_coeff}_reg{opts.length_cost}_max_len{opts.max_len}')
    # for i in range(30):
    #     for k in range(30):
    #         if i < k:
    #             all_messages = dump(trainer.game, opts.n_features, device, False, pos_m=i, pos_M=k)

    all_messages = dump(trainer.game, opts.n_features, device, False)
    print(all_messages)

    # freq = np.zeros(30)
    # for message in all_messages[0]:
    #     for i in range(message.shape[0]):
    #         freq[int(message[i])] += 1
    # print(freq)

    core.close()
Example #24
0
def main(params):
    opts = get_params(params)
    print(opts)

    device = opts.device

    n_a, n_v = opts.n_a, opts.n_v
    opts.vocab_size = n_v

    train_data = AttributeValueData(n_attributes=n_a,
                                    n_values=n_v,
                                    mul=1,
                                    mode='train')
    train_loader = DataLoader(train_data,
                              batch_size=opts.batch_size,
                              shuffle=True)

    test_data = AttributeValueData(n_attributes=n_a,
                                   n_values=n_v,
                                   mul=1,
                                   mode='test')
    test_loader = DataLoader(test_data,
                             batch_size=opts.batch_size,
                             shuffle=False)

    print(f'# Size of train {len(train_data)} test {len(test_data)}')

    if opts.language == 'identity':
        sender = IdentitySender(n_a, n_v)
    elif opts.language == 'rotated':
        sender = RotatedSender(n_a, n_v)
    else:
        assert False

    receiver = Receiver(n_hidden=opts.receiver_hidden,
                        n_dim=n_a * n_v,
                        inner_layers=opts.receiver_layers)
    receiver = core.RnnReceiverDeterministic(
        receiver,
        opts.vocab_size + 1,  # exclude eos = 0
        opts.receiver_emb,
        opts.receiver_hidden,
        cell=opts.receiver_cell,
        num_layers=opts.cell_layers)

    diff_loss = DiffLoss(n_a, n_v, loss_type=opts.loss_type)

    game = core.SenderReceiverRnnReinforce(sender,
                                           receiver,
                                           diff_loss,
                                           receiver_entropy_coeff=0.05,
                                           sender_entropy_coeff=0.0)

    optimizer = core.build_optimizer(receiver.parameters())
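    # only the receiver is optimized here: IdentitySender/RotatedSender realize
    # a fixed, hand-designed language rather than a trainable agent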
    loss = game.loss

    early_stopper = core.EarlyStopperAccuracy(1.0, validation=False)

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=[
                               core.ConsoleLogger(as_json=True,
                                                  print_train_loss=True),
                               early_stopper
                           ],
                           grad_norm=1.0)

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
Example #25
0
File: train.py Project: Shawn-Guo-CN/EGG
def main(params):
    opts = get_params(params)

    device = torch.device('cuda' if opts.cuda else 'cpu')

    data_loader = VectorsLoader(
        perceptual_dimensions=opts.perceptual_dimensions,
        n_distractors=opts.n_distractors,
        batch_size=opts.batch_size,
        train_samples=opts.train_samples,
        validation_samples=opts.validation_samples,
        test_samples=opts.test_samples,
        shuffle_train_data=opts.shuffle_train_data,
        dump_data_folder=opts.dump_data_folder,
        load_data_path=opts.load_data_path,
        seed=opts.data_seed)

    train_data, validation_data, test_data = data_loader.get_iterators()

    data_loader.upd_cl_options(opts)

    if opts.max_len > 1:
        baseline_msg = 'Cannot yet compute "smart" baseline value for messages of length greater than 1'
    else:
        baseline_msg = f'\n| Baselines measures with {opts.n_distractors} distractors and messages of max_len = {opts.max_len}:\n' \
            f'| Dummy random baseline: accuracy = {1 / (opts.n_distractors + 1)}\n'
        if -1 not in opts.perceptual_dimensions:
            baseline_msg += f'| "Smart" baseline with perceptual_dimensions {opts.perceptual_dimensions} = {compute_baseline_accuracy(opts.n_distractors, opts.max_len, *opts.perceptual_dimensions)}\n'
        else:
            baseline_msg += f'| Data was loaded from an external file, thus no perceptual_dimension vector was provided; the "smart baseline" cannot be computed\n'

    print(baseline_msg)
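    # e.g. with 4 distractors, the dummy random baseline is 1 / (4 + 1) = 0.2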

    sender = Sender(n_features=data_loader.n_features,
                    n_hidden=opts.sender_hidden)

    receiver = Receiver(n_features=data_loader.n_features,
                        linear_units=opts.receiver_hidden)

    if opts.mode.lower() == 'gs':
        sender = core.RnnSenderGS(sender,
                                  opts.vocab_size,
                                  opts.sender_embedding,
                                  opts.sender_hidden,
                                  cell=opts.sender_cell,
                                  max_len=opts.max_len,
                                  temperature=opts.temperature,
                                  force_eos=False)

        receiver = core.RnnReceiverGS(receiver,
                                      opts.vocab_size,
                                      opts.receiver_embedding,
                                      opts.receiver_hidden,
                                      cell=opts.receiver_cell)

        game = core.SenderReceiverRnnGS(sender, receiver, loss)
    else:
        raise NotImplementedError(f'Unknown training mode, {opts.mode}')

    optimizer = torch.optim.Adam([{
        'params': game.sender.parameters(),
        'lr': opts.sender_lr
    }, {
        'params': game.receiver.parameters(),
        'lr': opts.receiver_lr
    }])
    callbacks = [core.ConsoleLogger(as_json=True)]
    if opts.mode.lower() == 'gs':
        callbacks.append(
            core.TemperatureUpdater(agent=sender, decay=0.9, minimum=0.1))
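        # the updater multiplies the sender's Gumbel-Softmax temperature by 0.9
        # each epoch, never letting it drop below 0.1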
    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_data,
                           validation_data=validation_data,
                           callbacks=callbacks)
    trainer.train(n_epochs=opts.n_epochs)

    if opts.evaluate:
        is_gs = opts.mode == 'gs'
        sender_inputs, messages, receiver_inputs, receiver_outputs, labels = dump_sender_receiver(
            game, test_data, is_gs, variable_length=True, device=device)

        receiver_outputs = move_to(receiver_outputs, device)
        labels = move_to(labels, device)

        receiver_outputs = torch.stack(receiver_outputs)
        labels = torch.stack(labels)

        tensor_accuracy = receiver_outputs.argmax(dim=1) == labels
        accuracy = torch.mean(tensor_accuracy.float()).item()

        unique_dict = {}

        for elem in sender_inputs:
            target = ""
            for dim in elem:
                target += f'{str(int(dim.item()))}-'
            target = target[:-1]
            if target not in unique_dict:
                unique_dict[target] = True

        print(f'| Accuracy on test set: {accuracy}')

        compute_mi_input_msgs(sender_inputs, messages)

        print(f'entropy sender inputs {entropy(sender_inputs)}')
        print(f'mi sender inputs msgs {mutual_info(sender_inputs, messages)}')

        if opts.dump_msg_folder:
            opts.dump_msg_folder.mkdir(exist_ok=True)
            msg_dict = {}

            output_msg = f'messages_{opts.perceptual_dimensions}_vocab_{opts.vocab_size}' \
                        f'_maxlen_{opts.max_len}_bsize_{opts.batch_size}' \
                        f'_n_distractors_{opts.n_distractors}_train_size_{opts.train_samples}' \
                        f'_valid_size_{opts.validation_samples}_test_size_{opts.test_samples}' \
                        f'_slr_{opts.sender_lr}_rlr_{opts.receiver_lr}_shidden_{opts.sender_hidden}' \
                        f'_rhidden_{opts.receiver_hidden}_semb_{opts.sender_embedding}' \
                        f'_remb_{opts.receiver_embedding}_mode_{opts.mode}' \
                        f'_scell_{opts.sender_cell}_rcell_{opts.receiver_cell}.msg'

            output_file = opts.dump_msg_folder / output_msg
            with open(output_file, 'w') as f:
                f.write(f'{opts}\n')
                for sender_input, message, receiver_input, receiver_output, label \
                        in zip(sender_inputs, messages, receiver_inputs, receiver_outputs, labels):
                    sender_input = ','.join(map(str, sender_input.tolist()))
                    message = ','.join(map(str, message.tolist()))
                    distractors_list = receiver_input.tolist()
                    receiver_input = '; '.join([
                        ','.join(map(str, elem)) for elem in distractors_list
                    ])
                    if is_gs:
                        receiver_output = receiver_output.argmax()
                    f.write(
                        f'{sender_input} -> {receiver_input} -> {message} -> {receiver_output} (label={label.item()})\n'
                    )

                    if message in msg_dict:
                        msg_dict[message] += 1
                    else:
                        msg_dict[message] = 1

                sorted_msgs = sorted(msg_dict.items(),
                                     key=operator.itemgetter(1),
                                     reverse=True)
                f.write(
                    f'\nUnique target vectors seen by sender: {len(unique_dict.keys())}\n'
                )
                f.write(
                    f'Unique messages produced by sender: {len(msg_dict.keys())}\n'
                )
                f.write(f"Messages: 'msg' : msg_count: {str(sorted_msgs)}\n")
                f.write(f'\nAccuracy: {accuracy}')
Example #26
0
File: train.py Project: Gromite/EGG
def main(params):
    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features + 1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(',')],
                         dtype=np.float32)
    probs /= probs.sum()

    print('the probs are: ', probs, flush=True)

    train_loader = OneHotLoader(n_features=opts.n_features,
                                batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch,
                                probs=probs)

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    if opts.sender_cell == 'transformer':
        sender = Sender(n_features=opts.n_features,
                        n_hidden=opts.sender_embedding)
        sender = core.TransformerSenderReinforce(
            agent=sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            max_len=opts.max_len,
            num_layers=opts.sender_num_layers,
            num_heads=opts.sender_num_heads,
            hidden_size=opts.sender_hidden,
            force_eos=opts.force_eos,
            generate_style=opts.sender_generate_style,
            causal=opts.causal_sender)
    else:
        sender = Sender(n_features=opts.n_features,
                        n_hidden=opts.sender_hidden)

        sender = core.RnnSenderReinforce(sender,
                                         opts.vocab_size,
                                         opts.sender_embedding,
                                         opts.sender_hidden,
                                         cell=opts.sender_cell,
                                         max_len=opts.max_len,
                                         num_layers=opts.sender_num_layers,
                                         force_eos=force_eos)
    if opts.receiver_cell == 'transformer':
        receiver = Receiver(n_features=opts.n_features,
                            n_hidden=opts.receiver_embedding)
        receiver = core.TransformerReceiverDeterministic(
            receiver,
            opts.vocab_size,
            opts.max_len,
            opts.receiver_embedding,
            opts.receiver_num_heads,
            opts.receiver_hidden,
            opts.receiver_num_layers,
            causal=opts.causal_receiver)
    else:
        receiver = Receiver(n_features=opts.n_features,
                            n_hidden=opts.receiver_hidden)
        receiver = core.RnnReceiverDeterministic(
            receiver,
            opts.vocab_size,
            opts.receiver_embedding,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
            num_layers=opts.receiver_num_layers)

    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=opts.receiver_entropy_coeff,
        length_cost=opts.length_cost)

    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        EarlyStopperAccuracy(opts.early_stopping_thr),
        core.ConsoleLogger(as_json=True, print_train_loss=True)
    ]

    if opts.checkpoint_dir:
        checkpoint_name = f'{opts.name}_vocab{opts.vocab_size}_rs{opts.random_seed}_lr{opts.lr}_shid{opts.sender_hidden}_rhid{opts.receiver_hidden}_sentr{opts.sender_entropy_coeff}_reg{opts.length_cost}_max_len{opts.max_len}'
        callbacks.append(
            core.CheckpointSaver(checkpoint_path=opts.checkpoint_dir,
                                 prefix=checkpoint_name))

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=callbacks)

    trainer.train(n_epochs=opts.n_epochs)

    dump(trainer.game, opts.n_features, device, False)
    core.close()
Example #27
0
        game = SenderReceiverRnnGS(sender, receiver, loss)

    optimizer = torch.optim.Adam([{
        'params': game.sender.parameters(),
        'lr': opts.sender_lr
    }, {
        'params': game.receiver.parameters(),
        'lr': opts.receiver_lr
    }])

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=validation_loader,
                           callbacks=[
                               core.ConsoleLogger(print_train_loss=True,
                                                  as_json=True),
                               TopographicSimilarity(
                                   sender_input_distance_fn='cosine',
                                   message_distance_fn='edit',
                                   is_gumbel=True)
                           ])

    if dump_loader is not None:
        print("Printing vocab...")
        lines = []
        for vector, label in dump_loader:
            vector = torch.tensor(vector).cuda()
            with torch.no_grad():
                message, _ = sender(vector)
                message = message.argmax(dim=-1).cpu().numpy()[0]
                message = ' '.join([str(i) for i in message])
Example #28
0
File: train.py Project: evdcush/EGG
def main(params):
    opts = get_params(params)

    device = torch.device("cuda" if opts.cuda else "cpu")
    train_loader = OneHotLoader(n_features=opts.n_features,
                                batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch)
    test_loader = OneHotLoader(n_features=opts.n_features,
                               batch_size=opts.batch_size,
                               batches_per_epoch=opts.batches_per_epoch,
                               seed=7)

    sender = Sender(n_hidden=opts.sender_hidden, n_features=opts.n_features)
    receiver = Receiver(n_features=opts.n_features,
                        n_hidden=opts.receiver_hidden)

    if opts.mode.lower() == 'rf':
        sender = core.RnnSenderReinforce(sender,
                                         opts.vocab_size,
                                         opts.sender_embedding,
                                         opts.sender_hidden,
                                         cell=opts.sender_cell,
                                         max_len=opts.max_len)
        receiver = core.RnnReceiverDeterministic(receiver,
                                                 opts.vocab_size,
                                                 opts.receiver_embedding,
                                                 opts.receiver_hidden,
                                                 cell=opts.receiver_cell)

        game = core.SenderReceiverRnnReinforce(
            sender,
            receiver,
            loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff)
        callbacks = []
    elif opts.mode.lower() == 'gs':
        sender = core.RnnSenderGS(sender,
                                  opts.vocab_size,
                                  opts.sender_embedding,
                                  opts.sender_hidden,
                                  cell=opts.sender_cell,
                                  max_len=opts.max_len,
                                  temperature=opts.temperature)

        receiver = core.RnnReceiverGS(receiver,
                                      opts.vocab_size,
                                      opts.receiver_embedding,
                                      opts.receiver_hidden,
                                      cell=opts.receiver_cell)

        game = core.SenderReceiverRnnGS(sender, receiver, loss)
        callbacks = [
            core.TemperatureUpdater(agent=sender, decay=0.9, minimum=0.1)
        ]
    else:
        raise NotImplementedError(f'Unknown training mode, {opts.mode}')

    optimizer = torch.optim.Adam([{
        'params': game.sender.parameters(),
        'lr': opts.sender_lr
    }, {
        'params': game.receiver.parameters(),
        'lr': opts.receiver_lr
    }])

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=callbacks +
                           [core.ConsoleLogger(as_json=True)])
    trainer.train(n_epochs=opts.n_epochs)

    core.close()
Example #29
0
        'lr': opts.lr
    } for agent in agents])
    neptune.init(project_qualified_name='anonymous/anonymous',
                 backend=neptune.OfflineBackend())
    with neptune.create_experiment(params=vars(opts),
                                   upload_source_files=get_filepaths(),
                                   tags=['']) as experiment:
        trainer = core.Trainer(
            game=game,
            optimizer=optimizer,
            train_data=train_loader,
            validation_data=test_loader,
            callbacks=[
                CompositionalityMetricObverter(full_dataset,
                                               agents[0],
                                               opts,
                                               opts.vocab_size,
                                               prefix='1_'),
                CompositionalityMetricObverter(full_dataset,
                                               agents[1],
                                               opts,
                                               opts.vocab_size,
                                               prefix='2_'),
                NeptuneMonitor(),
                core.ConsoleLogger(print_train_loss=not opts.on_slurm),
                EarlyStopperAccuracy(threshold=0.99,
                                     field_name='accuracy',
                                     delay=5)
            ])
        trainer.train(n_epochs=opts.n_epochs)
Example #30
0
def main(params):
    import copy
    opts = get_params(params)
    device = opts.device

    train_loader, validation_loader = get_dsprites_dataloader(
        path_to_data='egg/zoo/data_loaders/data/dsprites.npz',
        batch_size=opts.batch_size,
        subsample=opts.subsample,
        image=False)

    n_dim = opts.n_attributes

    if opts.receiver_cell in ['lstm', 'rnn', 'gru']:
        receiver = Receiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim)
        receiver = core.RnnReceiverDeterministic(receiver,
                                                 opts.vocab_size + 1,
                                                 opts.receiver_emb,
                                                 opts.receiver_hidden,
                                                 cell=opts.receiver_cell)
    else:
        raise ValueError(f'Unknown receiver cell, {opts.receiver_cell}')

    if opts.sender_cell in ['lstm', 'rnn', 'gru']:
        sender = Sender(n_inputs=n_dim, n_hidden=opts.sender_hidden)
        sender = core.RnnSenderReinforce(agent=sender,
                                         vocab_size=opts.vocab_size,
                                         embed_dim=opts.sender_emb,
                                         hidden_size=opts.sender_hidden,
                                         max_len=opts.max_len,
                                         force_eos=False,
                                         cell=opts.sender_cell)
    else:
        raise ValueError(f'Unknown sender cell, {opts.sender_cell}')

    sender = PlusOneWrapper(sender)
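    # PlusOneWrapper (assumed behavior) shifts every emitted symbol by +1 so
    # that 0 stays reserved for EOS, which is why the receivers in this script
    # are built with vocab_size + 1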
    loss = DiffLoss(opts.n_attributes, opts.n_values)

    baseline = {
        'no': core.baselines.NoBaseline,
        'mean': core.baselines.MeanBaseline,
        'builtin': core.baselines.BuiltInBaseline
    }[opts.baseline]

    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=0.0,
        length_cost=0.0,
        baseline_type=baseline)
    optimizer = torch.optim.Adam(game.parameters(), lr=opts.lr)

    latent_values, _ = zip(*[batch for batch in validation_loader])
    latent_values = latent_values[0]

    metrics_evaluator = Metrics(latent_values,
                                opts.device,
                                opts.n_attributes,
                                opts.n_values,
                                opts.vocab_size + 1,
                                freq=opts.stats_freq)

    loaders = []
    #loaders.append(("generalization hold out", generalization_holdout_loader, DiffLoss(
    #opts.n_attributes, opts.n_values, generalization=True)))
    loaders.append(("uniform holdout", validation_loader,
                    DiffLoss(opts.n_attributes, opts.n_values)))

    holdout_evaluator = Evaluator(loaders, opts.device, freq=0)
    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr,
                                         validation=True)

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=validation_loader,
                           callbacks=[
                               core.ConsoleLogger(as_json=True,
                                                  print_train_loss=False),
                               early_stopper, metrics_evaluator,
                               holdout_evaluator
                           ])
    trainer.train(n_epochs=opts.n_epochs)

    last_epoch_interaction = early_stopper.validation_stats[-1][1]
    validation_acc = last_epoch_interaction.aux['acc'].mean()

    uniformtest_acc = holdout_evaluator.results['uniform holdout']['acc']

    # Train new agents
    if validation_acc > 0.99:

        def _set_seed(seed):
            import random
            import numpy as np

            random.seed(seed)
            torch.manual_seed(seed)
            np.random.seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        core.get_opts().preemptable = False
        core.get_opts().checkpoint_path = None

        # freeze Sender and probe how fast a simple Receiver will learn the thing
        def retrain_receiver(receiver_generator, sender):
            receiver = receiver_generator()
            game = core.SenderReceiverRnnReinforce(sender,
                                                   receiver,
                                                   loss,
                                                   sender_entropy_coeff=0.0,
                                                   receiver_entropy_coeff=0.0)
            optimizer = torch.optim.Adam(receiver.parameters(), lr=opts.lr)
            early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr,
                                                 validation=True)

            trainer = core.Trainer(game=game,
                                   optimizer=optimizer,
                                   train_data=train_loader,
                                   validation_data=validation_loader,
                                   callbacks=[
                                       early_stopper,
                                       Evaluator(loaders, opts.device, freq=0)
                                   ])
            trainer.train(n_epochs=opts.n_epochs // 2)

            accs = [x[1]['acc'] for x in early_stopper.validation_stats]
            return accs

        frozen_sender = Freezer(copy.deepcopy(sender))

        def gru_receiver_generator():
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim),
                opts.vocab_size + 1, opts.receiver_emb,
                hidden_size=opts.receiver_hidden, cell='gru')

        def small_gru_receiver_generator():
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=100, n_outputs=n_dim),
                opts.vocab_size + 1, opts.receiver_emb,
                hidden_size=100, cell='gru')

        def tiny_gru_receiver_generator():
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=50, n_outputs=n_dim),
                opts.vocab_size + 1, opts.receiver_emb,
                hidden_size=50, cell='gru')

        def nonlinear_receiver_generator():
            return NonLinearReceiver(
                n_outputs=n_dim, vocab_size=opts.vocab_size + 1,
                max_length=opts.max_len, n_hidden=opts.receiver_hidden)

        for name, receiver_generator in [
            ('gru', gru_receiver_generator),
            ('nonlinear', nonlinear_receiver_generator),
            ('tiny_gru', tiny_gru_receiver_generator),
            ('small_gru', small_gru_receiver_generator),
        ]:

            for seed in range(17, 17 + 3):
                _set_seed(seed)
                accs = retrain_receiver(receiver_generator, frozen_sender)
                accs += [1.0] * (opts.n_epochs // 2 - len(accs))
                auc = sum(accs)
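                # padding post-early-stopping epochs with 1.0 and summing gives
                # an area-under-the-learning-curve score: higher means the
                # frozen sender's language was faster to learn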
                print(
                    json.dumps({
                        "mode": "reset",
                        "seed": seed,
                        "receiver_name": name,
                        "auc": auc
                    }))

    print('---End---')

    core.close()
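These training scripts are launched from the command line in the usual EGG fashion; a hypothetical invocation, with flag names inferred from the opts fields used above:

python train.py --n_features 100 --vocab_size 40 --max_len 10 \
    --batch_size 512 --n_epochs 500 --sender_cell lstm --receiver_cell lstm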