Example #1
def get_synthesizer_model(
        vocab_size=len(hparams.VOCAB),
        char_embed_size=hparams.CHAR_EMBED_SIZE,
        spk_embed_size=hparams.SPK_EMBED_SIZE,
        enc_conv1_bank_depth=hparams.ENC_CONV1_BANK_DEPTH,
        enc_convprojec_filters1=hparams.ENC_CONVPROJEC_FILTERS1,
        enc_convprojec_filters2=hparams.ENC_CONVPROJEC_FILTERS2,
        enc_highway_depth=hparams.ENC_HIGHWAY_DEPTH,
        hidden_size=hparams.HIDDEN_SIZE,
        post_conv1_bank_depth=hparams.POST_CONV1_BANK_DEPTH,
        post_convprojec_filters1=hparams.POST_CONVPROJEC_FILTERS1,
        post_convprojec_filters2=hparams.POST_CONVPROJEC_FILTERS2,
        post_highway_depth=hparams.POST_HIGHWAY_DEPTH,
        attention_dim=hparams.ATTENTION_DIM,
        target_size=hparams.TARGET_MAG_FRAME_SIZE,
        n_mels=hparams.SYNTHESIZER_N_MELS,
        output_per_step=hparams.OUTPUT_PER_STEP,
        learning_rate=hparams.LEARNING_RATE,
        clipnorm=hparams.CLIPNORM,
        enc_seq_len=None,
        dec_seq_len=None):
    char_inputs = Input(shape=(enc_seq_len, ), name='char_inputs')
    decoder_inputs = Input(shape=(dec_seq_len, n_mels), name='decoder_inputs')
    spk_embed_inputs = Input(shape=(spk_embed_size, ), name='spk_embed_inputs')

    char_encoder = Encoder(hidden_size=hidden_size // 2,
                           vocab_size=vocab_size,
                           embedding_size=char_embed_size,
                           conv1d_bank_depth=enc_conv1_bank_depth,
                           convprojec_filters1=enc_convprojec_filters1,
                           convprojec_filters2=enc_convprojec_filters2,
                           highway_depth=enc_highway_depth,
                           name='char_encoder')
    condition = Conditioning()
    decoder = Decoder(hidden_size=hidden_size,
                      attention_dim=attention_dim,
                      n_mels=n_mels,
                      output_per_step=output_per_step,
                      name='decoder')
    post_processing = PostProcessing(
        hidden_size=hidden_size // 2,
        conv1d_bank_depth=post_conv1_bank_depth,
        convprojec_filters1=post_convprojec_filters1,
        convprojec_filters2=post_convprojec_filters2,
        highway_depth=post_highway_depth,
        n_fft=target_size,
        name='postprocessing')

    char_enc = char_encoder(char_inputs)
    conditioned_char_enc = condition([char_enc, spk_embed_inputs])
    decoder_pred, alignments = decoder([conditioned_char_enc, decoder_inputs],
                                       initial_state=None)
    postnet_out = post_processing(decoder_pred)

    synthesizer_model = Model(
        inputs=[char_inputs, spk_embed_inputs, decoder_inputs],
        outputs=[decoder_pred, postnet_out, alignments])
    optimizer = Adam(lr=learning_rate, clipnorm=clipnorm)
    synthesizer_model.compile(optimizer=optimizer,
                              loss=['mae', 'mae', None],
                              loss_weights=[1., 1., None])

    return synthesizer_model
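A minimal usage sketch for the builder above (the shapes are hypothetical; hparams and the custom layers are assumed to be importable from this repo):

import numpy as np

# Build with default hyperparameters and push one dummy batch through.
model = get_synthesizer_model()
chars = np.zeros((2, 60), dtype='int32')               # padded character ids
spk = np.zeros((2, hparams.SPK_EMBED_SIZE))            # speaker embeddings
mels = np.zeros((2, 100, hparams.SYNTHESIZER_N_MELS))  # teacher-forced mel frames
decoder_out, postnet_out, alignments = model.predict([chars, spk, mels])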
Example #2
def main():
    parser = argparse.ArgumentParser()

    default_dataset = 'toy-data.npz'
    parser.add_argument('--data', default=default_dataset, help='data file')
    parser.add_argument('--seed',
                        type=int,
                        default=None,
                        help='random seed. Randomly set if not specified.')

    # training options
    parser.add_argument('--nz',
                        type=int,
                        default=32,
                        help='dimension of latent variable')
    parser.add_argument('--epoch',
                        type=int,
                        default=1000,
                        help='number of training epochs')
    parser.add_argument('--batch-size',
                        type=int,
                        default=128,
                        help='batch size')
    parser.add_argument('--lr',
                        type=float,
                        default=8e-5,
                        help='encoder/decoder learning rate')
    parser.add_argument('--dis-lr',
                        type=float,
                        default=1e-4,
                        help='discriminator learning rate')
    parser.add_argument('--min-lr',
                        type=float,
                        default=5e-5,
                        help='min encoder/decoder learning rate for LR '
                        'scheduler. -1 to disable annealing')
    parser.add_argument('--min-dis-lr',
                        type=float,
                        default=7e-5,
                        help='min discriminator learning rate for LR '
                        'scheduler. -1 to disable annealing')
    parser.add_argument('--wd', type=float, default=0, help='weight decay')
    parser.add_argument('--overlap',
                        type=float,
                        default=.5,
                        help='kernel overlap')
    parser.add_argument('--no-norm-trans',
                        action='store_true',
                        help='if set, use Gaussian posterior without '
                        'transformation')
    parser.add_argument('--plot-interval',
                        type=int,
                        default=1,
                        help='plot interval. 0 to disable plotting.')
    parser.add_argument('--save-interval',
                        type=int,
                        default=0,
                        help='interval to save models. 0 to disable saving.')
    parser.add_argument('--prefix',
                        default='pbigan',
                        help='prefix of output directory')
    parser.add_argument('--comp',
                        type=int,
                        default=7,
                        help='continuous convolution kernel size')
    parser.add_argument('--ae',
                        type=float,
                        default=.2,
                        help='autoencoding regularization strength')
    parser.add_argument('--aeloss',
                        default='smooth_l1',
                        help='autoencoding loss. (options: mse, smooth_l1)')
    parser.add_argument('--ema',
                        dest='ema',
                        type=int,
                        default=-1,
                        help='start epoch of exponential moving average '
                        '(EMA). -1 to disable EMA')
    parser.add_argument('--ema-decay',
                        type=float,
                        default=.9999,
                        help='EMA decay')
    parser.add_argument('--mmd',
                        type=float,
                        default=1,
                        help='MMD strength for latent variable')

    # squash is off when rescale is off
    parser.add_argument('--squash',
                        dest='squash',
                        action='store_const',
                        const=True,
                        default=True,
                        help='bound the generated time series value '
                        'using tanh')
    parser.add_argument('--no-squash',
                        dest='squash',
                        action='store_const',
                        const=False)

    # rescale to [-1, 1]
    parser.add_argument('--rescale',
                        dest='rescale',
                        action='store_const',
                        const=True,
                        default=True,
                        help='if set, rescale time to [-1, 1]')
    parser.add_argument('--no-rescale',
                        dest='rescale',
                        action='store_const',
                        const=False)

    args = parser.parse_args()

    batch_size = args.batch_size
    nz = args.nz

    epochs = args.epoch
    plot_interval = args.plot_interval
    save_interval = args.save_interval

    try:
        npz = np.load(args.data)
        train_data = npz['data']
        train_time = npz['time']
        train_mask = npz['mask']
    except FileNotFoundError:
        if args.data != default_dataset:
            raise
        # Generate the default toy dataset from scratch
        train_data, train_time, train_mask, _, _ = gen_data(
            n_samples=10000,
            seq_len=200,
            max_time=1,
            poisson_rate=50,
            obs_span_rate=.25,
            save_file=default_dataset)

    _, in_channels, seq_len = train_data.shape
    train_time *= train_mask

    if args.seed is None:
        rnd = np.random.RandomState(None)
        random_seed = rnd.randint(np.iinfo(np.uint32).max)
    else:
        random_seed = args.seed
    rnd = np.random.RandomState(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    # Scale time
    max_time = 5
    train_time *= max_time

    squash = None
    rescaler = None
    if args.rescale:
        rescaler = Rescaler(train_data)
        train_data = rescaler.rescale(train_data)
        if args.squash:
            squash = torch.tanh

    out_channels = 64
    cconv_ref = 98

    train_dataset = TimeSeries(train_data,
                               train_time,
                               train_mask,
                               label=None,
                               max_time=max_time,
                               cconv_ref=cconv_ref,
                               overlap_rate=args.overlap,
                               device=device)

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              drop_last=True,
                              collate_fn=train_dataset.collate_fn)
    n_train_batch = len(train_loader)

    time_loader = DataLoader(train_dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             drop_last=True,
                             collate_fn=train_dataset.collate_fn)

    test_loader = DataLoader(train_dataset,
                             batch_size=batch_size,
                             collate_fn=train_dataset.collate_fn)

    grid_decoder = SeqGeneratorDiscrete(in_channels, nz, squash)
    decoder = Decoder(grid_decoder, max_time=max_time).to(device)

    cconv = ContinuousConv1D(in_channels,
                             out_channels,
                             max_time,
                             cconv_ref,
                             overlap_rate=args.overlap,
                             kernel_size=args.comp,
                             norm=True).to(device)
    encoder = Encoder(cconv, nz, not args.no_norm_trans).to(device)

    pbigan = PBiGAN(encoder, decoder, args.aeloss).to(device)

    critic_cconv = ContinuousConv1D(in_channels,
                                    out_channels,
                                    max_time,
                                    cconv_ref,
                                    overlap_rate=args.overlap,
                                    kernel_size=args.comp,
                                    norm=True).to(device)
    critic = ConvCritic(critic_cconv, nz).to(device)

    ema = None
    if args.ema >= 0:
        ema = EMA(pbigan, args.ema_decay, args.ema)

    optimizer = optim.Adam(pbigan.parameters(),
                           lr=args.lr,
                           weight_decay=args.wd)
    critic_optimizer = optim.Adam(critic.parameters(),
                                  lr=args.dis_lr,
                                  weight_decay=args.wd)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)
    dis_scheduler = make_scheduler(critic_optimizer, args.dis_lr,
                                   args.min_dis_lr, epochs)

    path = '{}_{}'.format(args.prefix, datetime.now().strftime('%m%d.%H%M%S'))

    output_dir = Path('results') / 'toy-pbigan' / path
    print(output_dir)
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    start_epoch = 0

    with (log_dir / 'seed.txt').open('w') as f:
        print(random_seed, file=f)
    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)
    with (log_dir / 'args.txt').open('w') as f:
        for key, val in sorted(vars(args).items()):
            print(f'{key}: {val}', file=f)

    tracker = Tracker(log_dir, n_train_batch)
    visualizer = Visualizer(encoder, decoder, batch_size, max_time,
                            test_loader, rescaler, output_dir, device)
    start = time.time()
    epoch_start = start

    for epoch in range(start_epoch, epochs):
        loss_breakdown = defaultdict(float)

        for ((val, idx, mask, _, cconv_graph),
             (_, idx_t, mask_t, index, _)) in zip(train_loader, time_loader):

            z_enc, x_recon, z_gen, x_gen, ae_loss = pbigan(
                val, idx, mask, cconv_graph, idx_t, mask_t)

            cconv_graph_gen = train_dataset.make_graph(x_gen, idx_t, mask_t,
                                                       index)

            real = critic(cconv_graph, batch_size, z_enc)
            fake = critic(cconv_graph_gen, batch_size, z_gen)

            D_loss = gan_loss(real, fake, 1, 0)

            critic_optimizer.zero_grad()
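            # keep the graph: G_loss below backprops through the same real/fake scores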
            D_loss.backward(retain_graph=True)
            critic_optimizer.step()

            G_loss = gan_loss(real, fake, 0, 1)

            mmd_loss = mmd(z_enc, z_gen)

            loss = G_loss + ae_loss * args.ae + mmd_loss * args.mmd

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if ema:
                ema.update()

            loss_breakdown['D'] += D_loss.item()
            loss_breakdown['G'] += G_loss.item()
            loss_breakdown['AE'] += ae_loss.item()
            loss_breakdown['MMD'] += mmd_loss.item()
            loss_breakdown['total'] += loss.item()

        if scheduler:
            scheduler.step()
        if dis_scheduler:
            dis_scheduler.step()

        cur_time = time.time()
        tracker.log(epoch, loss_breakdown, cur_time - epoch_start,
                    cur_time - start)

        if plot_interval > 0 and (epoch + 1) % plot_interval == 0:
            if ema:
                ema.apply()
                visualizer.plot(epoch)
                ema.restore()
            else:
                visualizer.plot(epoch)

        model_dict = {
            'pbigan': pbigan.state_dict(),
            'critic': critic.state_dict(),
            'ema': ema.state_dict() if ema else None,
            'epoch': epoch + 1,
            'args': args,
        }
        torch.save(model_dict, str(log_dir / 'model.pth'))
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
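The loop above guards on "if scheduler:", so make_scheduler presumably returns None when annealing is disabled (the help text says -1 disables it). A plausible sketch of such a helper using cosine annealing; this is an assumption, not necessarily the repo's actual implementation:

from torch.optim.lr_scheduler import CosineAnnealingLR

def make_scheduler(optimizer, lr, min_lr, epochs):
    # lr is already set on the optimizer; the argument is kept to match the call site.
    if min_lr < 0:
        return None
    return CosineAnnealingLR(optimizer, T_max=epochs, eta_min=min_lr)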
Example #3
    def __init__(self,
                 encoder_1,
                 hidden_1,
                 Z_DIMS,
                 decoder_share,
                 share_hidden,
                 decoder_1,
                 hidden_2,
                 encoder_l,
                 hidden3,
                 encoder_2,
                 hidden_4,
                 encoder_l1,
                 hidden3_1,
                 decoder_2,
                 hidden_5,
                 drop_rate,
                 log_variational=True,
                 Type='Bernoulli',
                 device='cpu',
                 n_centroids=19,
                 penality="GMM",
                 model=2):

        super(scMVAE_POE, self).__init__()

        self.X1_encoder = Encoder(encoder_1,
                                  hidden_1,
                                  Z_DIMS,
                                  dropout_rate=drop_rate)
        self.X1_encoder_l = Encoder(encoder_l,
                                    hidden3,
                                    1,
                                    dropout_rate=drop_rate)

        self.X1_decoder = Decoder_ZINB(decoder_1,
                                       hidden_2,
                                       encoder_1[0],
                                       dropout_rate=drop_rate)

        self.X2_encoder = Encoder(encoder_2,
                                  hidden_4,
                                  Z_DIMS,
                                  dropout_rate=drop_rate)

        self.decode_share = build_multi_layers(decoder_share,
                                               dropout_rate=drop_rate)

        if Type == 'ZINB':
            self.X2_encoder_l = Encoder(encoder_l1,
                                        hidden3_1,
                                        1,
                                        dropout_rate=drop_rate)
            self.decoder_x2 = Decoder_ZINB(decoder_2,
                                           hidden_5,
                                           encoder_2[0],
                                           dropout_rate=drop_rate)
        else:
            # 'Bernoulli', "Possion" (sic: Poisson) and any other Type all use
            # the generic Decoder; only the ZINB case above differs.
            self.decoder_x2 = Decoder(decoder_2,
                                      hidden_5,
                                      encoder_2[0],
                                      Type,
                                      dropout_rate=drop_rate)

        self.experts = ProductOfExperts()
        self.Z_DIMS = Z_DIMS
        self.share_hidden = share_hidden
        self.log_variational = log_variational
        self.Type = Type
        self.decoder_share = decoder_share
        self.decoder_1 = decoder_1
        self.n_centroids = n_centroids
        self.penality = penality
        self.device = device
        self.model = model

        self.pi = nn.Parameter(torch.ones(n_centroids) / n_centroids)  # pc
        self.mu_c = nn.Parameter(torch.zeros(Z_DIMS, n_centroids))  # mu
        self.var_c = nn.Parameter(torch.ones(Z_DIMS, n_centroids))  # sigma^2
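A hedged sketch of how GMM prior parameters like pi, mu_c and var_c above are typically used to compute cluster responsibilities for a "GMM" penalty (VaDE-style); the formulation is standard, but it is an assumption rather than this repo's exact code:

import math
import torch

def gmm_responsibilities(z, pi, mu_c, var_c):
    # z: [batch, Z_DIMS]; pi: [n_centroids]; mu_c, var_c: [Z_DIMS, n_centroids]
    z = z.unsqueeze(-1)                                    # [batch, Z_DIMS, 1]
    log_pdf = -0.5 * (math.log(2 * math.pi) + torch.log(var_c)
                      + (z - mu_c) ** 2 / var_c)           # per-dim log N(z; mu, var)
    log_p = torch.log(pi) + log_pdf.sum(dim=1)             # [batch, n_centroids]
    return torch.softmax(log_p, dim=-1)                    # gamma(c | z)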
Example #4
def get_full_model(vocab_size=len(hparams.VOCAB),
                   char_embed_size=hparams.CHAR_EMBED_SIZE,
                   sliding_window_size=hparams.SLIDING_WINDOW_SIZE,
                   spk_embed_lstm_units=hparams.SPK_EMBED_LSTM_UNITS,
                   spk_embed_size=hparams.SPK_EMBED_SIZE,
                   spk_embed_num_layers=hparams.SPK_EMBED_NUM_LAYERS,
                   enc_conv1_bank_depth=hparams.ENC_CONV1_BANK_DEPTH,
                   enc_convprojec_filters1=hparams.ENC_CONVPROJEC_FILTERS1,
                   enc_convprojec_filters2=hparams.ENC_CONVPROJEC_FILTERS2,
                   enc_highway_depth=hparams.ENC_HIGHWAY_DEPTH,
                   hidden_size=hparams.HIDDEN_SIZE,
                   post_conv1_bank_depth=hparams.POST_CONV1_BANK_DEPTH,
                   post_convprojec_filters1=hparams.POST_CONVPROJEC_FILTERS1,
                   post_convprojec_filters2=hparams.POST_CONVPROJEC_FILTERS2,
                   post_highway_depth=hparams.POST_HIGHWAY_DEPTH,
                   attention_dim=hparams.ATTENTION_DIM,
                   target_size=hparams.TARGET_MAG_FRAME_SIZE,
                   n_mels=hparams.SYNTHESIZER_N_MELS,
                   output_per_step=hparams.OUTPUT_PER_STEP,
                   embed_mels=hparams.SPK_EMBED_N_MELS,
                   enc_seq_len=None,
                   dec_seq_len=None):
    char_inputs = Input(shape=(enc_seq_len, ), name='char_inputs')
    decoder_inputs = Input(shape=(dec_seq_len, n_mels), name='decoder_inputs')
    spk_inputs = Input(shape=(None, sliding_window_size, embed_mels),
                       name='spk_embed_inputs')

    char_encoder = Encoder(hidden_size=hidden_size // 2,
                           vocab_size=vocab_size,
                           embedding_size=char_embed_size,
                           conv1d_bank_depth=enc_conv1_bank_depth,
                           convprojec_filters1=enc_convprojec_filters1,
                           convprojec_filters2=enc_convprojec_filters2,
                           highway_depth=enc_highway_depth,
                           name='char_encoder')
    speaker_encoder = InferenceSpeakerEmbedding(
        lstm_units=spk_embed_lstm_units,
        proj_size=spk_embed_size,
        num_layers=spk_embed_num_layers,
        trainable=False,
        name='embeddings')
    condition = Conditioning()
    decoder = Decoder(hidden_size=hidden_size,
                      attention_dim=attention_dim,
                      n_mels=n_mels,
                      output_per_step=output_per_step,
                      name='decoder')
    post_processing = PostProcessing(
        hidden_size=hidden_size // 2,
        conv1d_bank_depth=post_conv1_bank_depth,
        convprojec_filters1=post_convprojec_filters1,
        convprojec_filters2=post_convprojec_filters2,
        highway_depth=post_highway_depth,
        n_fft=target_size,
        name='postprocessing')

    char_enc = char_encoder(char_inputs)
    spk_embed = speaker_encoder(spk_inputs)
    conditioned_char_enc = condition([char_enc, spk_embed])
    decoder_pred, alignments = decoder([conditioned_char_enc, decoder_inputs],
                                       initial_state=None)
    postnet_out = post_processing(decoder_pred)

    full_model = Model(
        inputs=[char_inputs, spk_inputs, decoder_inputs],
        outputs=[decoder_pred, postnet_out, alignments, spk_embed])
    return full_model
Example #5
    def __init__(self, dx=2, dh=128, dff=512, N=3, M=8, samples=1280):
        super(AttentionModel, self).__init__()

        self.samples = samples
        self.encoder = Encoder(dx, dh, dff, N, N)
        self.decoder = Decoder(dx, dh, M, samples)
Example #6
    def __init__(self,
                 encoder_1,
                 encoder_2,
                 share_e,
                 hidden,
                 zdim,
                 share_d,
                 decoder_1,
                 hidden1,
                 decoder_2,
                 hidden2,
                 encoder_l,
                 hidden_l,
                 encoder_l1,
                 hidden_l1,
                 laste_hidden,
                 logvariantional=True,
                 drop_rate=0.1,
                 drop_rate_d=0.1,
                 Type1='ZINB',
                 Type='ZINB',
                 pair=False,
                 mode=0,
                 library_mode=0,
                 n_centroids=19,
                 penality="GMM"):

        super(scMVAE_NN, self).__init__()

        self.encoder_x1 = build_multi_layers(encoder_1, dropout_rate=drop_rate)
        self.encoder_x2 = build_multi_layers(encoder_2, dropout_rate=drop_rate)

        self.encoder_share = Encoder(share_e,
                                     hidden,
                                     zdim,
                                     dropout_rate=drop_rate)
        self.decoder_share = build_multi_layers(share_d,
                                                dropout_rate=drop_rate_d)

        self.decoder_x1 = Decoder_ZINB(decoder_1,
                                       hidden1,
                                       encoder_1[0],
                                       dropout_rate=drop_rate_d)

        #self.decoder_x1  =  Decoder( decoder_1, hidden1, encoder_1[0], Type1, dropout_rate = drop_rate )

        if library_mode == 0:
            self.encoder_l = Encoder(encoder_l, hidden_l, 1)
        else:
            self.encoder_l = Encoder([128], encoder_1[-1], 1)

        if Type == "ZINB":
            self.encoder_l2 = Encoder(encoder_l1,
                                      hidden_l1,
                                      1,
                                      dropout_rate=drop_rate)
            self.decoder_x2 = Decoder_ZINB(decoder_2,
                                           hidden2,
                                           encoder_2[0],
                                           dropout_rate=drop_rate_d)
        else:
            self.decoder_x2 = Decoder(decoder_2,
                                      hidden2,
                                      encoder_2[0],
                                      Type,
                                      dropout_rate=drop_rate_d)

        # parameters
        self.logvariantional = logvariantional
        self.hidden = laste_hidden
        self.Type = Type
        self.Type1 = Type1
        self.pair = pair
        self.mode = mode
        self.library_mode = library_mode
        self.n_centroids = n_centroids
        self.penality = penality

        self.pi = nn.Parameter(torch.ones(n_centroids) / n_centroids)  # pc
        self.mu_c = nn.Parameter(torch.zeros(zdim, n_centroids))  # mu
        self.var_c = nn.Parameter(torch.ones(zdim, n_centroids))  # sigma^2
Example #7
class Model(nn.Module):
    """Model"""
    def __init__(self, config):
        super(Model, self).__init__()
        self.config = config

        self.init_embeddings()
        self.init_model()

    def init_embeddings(self):
        embed_dim = self.config['embed_dim']
        tie_mode = self.config['tie_mode']
        max_pos_length = self.config['max_pos_length']
        learned_pos = self.config['learned_pos']

        # get positional embedding
        if not learned_pos:
            self.pos_embedding = ut.get_positional_encoding(
                embed_dim, max_pos_length)
        else:
            self.pos_embedding = Parameter(
                torch.Tensor(max_pos_length, embed_dim))
            nn.init.normal_(self.pos_embedding, mean=0, std=embed_dim**-0.5)

        # get word embeddings
        src_vocab_size, trg_vocab_size = ut.get_vocab_sizes(self.config)
        self.src_vocab_mask, self.trg_vocab_mask = ut.get_vocab_masks(
            self.config, src_vocab_size, trg_vocab_size)
        if tie_mode == ac.ALL_TIED:
            src_vocab_size = trg_vocab_size = self.trg_vocab_mask.shape[0]

        self.out_bias = Parameter(torch.Tensor(trg_vocab_size))
        nn.init.constant_(self.out_bias, 0.)

        self.src_embedding = nn.Embedding(src_vocab_size, embed_dim)
        self.trg_embedding = nn.Embedding(trg_vocab_size, embed_dim)
        self.out_embedding = self.trg_embedding.weight
        self.embed_scale = embed_dim**0.5

        if tie_mode == ac.ALL_TIED:
            self.src_embedding.weight = self.trg_embedding.weight

        nn.init.normal_(self.src_embedding.weight, mean=0, std=embed_dim**-0.5)
        nn.init.normal_(self.trg_embedding.weight, mean=0, std=embed_dim**-0.5)

    def init_model(self):
        num_enc_layers = self.config['num_enc_layers']
        num_enc_heads = self.config['num_enc_heads']
        num_dec_layers = self.config['num_dec_layers']
        num_dec_heads = self.config['num_dec_heads']

        embed_dim = self.config['embed_dim']
        ff_dim = self.config['ff_dim']
        dropout = self.config['dropout']

        # get encoder, decoder
        self.encoder = Encoder(num_enc_layers,
                               num_enc_heads,
                               embed_dim,
                               ff_dim,
                               dropout=dropout)
        self.decoder = Decoder(num_dec_layers,
                               num_dec_heads,
                               embed_dim,
                               ff_dim,
                               dropout=dropout)

        # leave layer norm alone
        init_func = (nn.init.xavier_normal_
                     if self.config['init_type'] == ac.XAVIER_NORMAL
                     else nn.init.xavier_uniform_)
        for m in [
                self.encoder.self_atts, self.encoder.pos_ffs,
                self.decoder.self_atts, self.decoder.pos_ffs,
                self.decoder.enc_dec_atts
        ]:
            for p in m.parameters():
                if p.dim() > 1:
                    init_func(p)
                else:
                    nn.init.constant_(p, 0.)

    def get_input(self, toks, is_src=True):
        embeds = self.src_embedding if is_src else self.trg_embedding
        word_embeds = embeds(toks)  # [bsz, max_len, embed_dim]
        pos_embeds = self.pos_embedding[:toks.size()[-1], :].unsqueeze(
            0)  # [1, max_len, embed_dim]
        return word_embeds * self.embed_scale + pos_embeds

    def forward(self, src_toks, trg_toks, targets):
        encoder_mask = (src_toks == ac.PAD_ID).unsqueeze(1).unsqueeze(
            2)  # [bsz, 1, 1, max_src_len]
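        # causal mask: True strictly above the diagonal blocks attention to future positions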
        decoder_mask = torch.triu(torch.ones(
            (trg_toks.size()[-1], trg_toks.size()[-1])),
                                  diagonal=1).type(trg_toks.type()) == 1
        decoder_mask = decoder_mask.unsqueeze(0).unsqueeze(1)

        encoder_inputs = self.get_input(src_toks, is_src=True)
        encoder_outputs = self.encoder(encoder_inputs, encoder_mask)

        decoder_inputs = self.get_input(trg_toks, is_src=False)
        decoder_outputs = self.decoder(decoder_inputs, decoder_mask,
                                       encoder_outputs, encoder_mask)

        logits = self.logit_fn(decoder_outputs)
        neglprobs = F.log_softmax(logits, -1)
        neglprobs = neglprobs * self.trg_vocab_mask.type(
            neglprobs.type()).reshape(1, -1)
        targets = targets.reshape(-1, 1)
        non_pad_mask = targets != ac.PAD_ID
        nll_loss = -neglprobs.gather(dim=-1, index=targets)[non_pad_mask]
        smooth_loss = -neglprobs.sum(dim=-1, keepdim=True)[non_pad_mask]

        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
        label_smoothing = self.config['label_smoothing']
        loss = (
            1.0 - label_smoothing
        ) * nll_loss + label_smoothing * smooth_loss / self.trg_vocab_mask.type(
            smooth_loss.type()).sum()

        return {'loss': loss, 'nll_loss': nll_loss}

    def logit_fn(self, decoder_output):
        logits = F.linear(decoder_output,
                          self.out_embedding,
                          bias=self.out_bias)
        logits = logits.reshape(-1, logits.size()[-1])
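        # -1e9 acts as -inf: entries outside the target vocabulary get ~0 probability after softmax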
        logits[:, ~self.trg_vocab_mask] = -1e9
        return logits

    def beam_decode(self, src_toks):
        encoder_mask = (src_toks == ac.PAD_ID).unsqueeze(1).unsqueeze(
            2)  # [bsz, 1, 1, max_src_len]
        encoder_inputs = self.get_input(src_toks, is_src=True)
        encoder_outputs = self.encoder(encoder_inputs, encoder_mask)
        max_lengths = torch.sum(src_toks != ac.PAD_ID, dim=-1).type(
            src_toks.type()) + 50

        def get_trg_inp(ids, time_step):
            ids = ids.type(src_toks.type())
            word_embeds = self.trg_embedding(ids)
            pos_embeds = self.pos_embedding[time_step, :].reshape(1, 1, -1)
            return word_embeds * self.embed_scale + pos_embeds

        def logprob(decoder_output):
            return F.log_softmax(self.logit_fn(decoder_output), dim=-1)

        return self.decoder.beam_decode(encoder_outputs,
                                        encoder_mask,
                                        get_trg_inp,
                                        logprob,
                                        ac.BOS_ID,
                                        ac.EOS_ID,
                                        max_lengths,
                                        beam_size=self.config['beam_size'],
                                        alpha=self.config['beam_alpha'])
Example #8
class Transformer(nn.Module):
    """Transformer https://arxiv.org/pdf/1706.03762.pdf"""
    def __init__(self, args):
        super(Transformer, self).__init__()
        self.args = args

        embed_dim = args.embed_dim
        fix_norm = args.fix_norm
        joint_vocab_size = args.joint_vocab_size
        lang_vocab_size = args.lang_vocab_size
        use_bias = args.use_bias
        self.scale = embed_dim**0.5

        if args.mask_logit:
            # mask logits separately per language
            self.logit_mask = None
        else:
            # otherwise, use the same mask for all
            # this only masks out BOS and PAD
            mask = [True] * joint_vocab_size
            mask[ac.BOS_ID] = False
            mask[ac.PAD_ID] = False
            self.logit_mask = torch.tensor(mask).type(torch.bool)

        self.word_embedding = Parameter(
            torch.Tensor(joint_vocab_size, embed_dim))
        self.lang_embedding = Parameter(
            torch.Tensor(lang_vocab_size, embed_dim))
        self.out_bias = Parameter(
            torch.Tensor(joint_vocab_size)) if use_bias else None

        self.encoder = Encoder(args)
        self.decoder = Decoder(args)

        # initialize
        nn.init.normal_(self.lang_embedding, mean=0, std=embed_dim**-0.5)
        if fix_norm:
            d = 0.01
            nn.init.uniform_(self.word_embedding, a=-d, b=d)
        else:
            nn.init.normal_(self.word_embedding, mean=0, std=embed_dim**-0.5)

        if use_bias:
            nn.init.constant_(self.out_bias, 0.)

    def replace_with_unk(self, toks):
        # word-dropout
        p = self.args.word_dropout
        if self.training and 0 < p < 1:
            non_pad_mask = toks != ac.PAD_ID
            mask = (torch.rand(toks.size()) <= p).type(non_pad_mask.type())
            mask = mask & non_pad_mask
            toks[mask] = ac.UNK_ID

    def get_input(self, toks, lang_idx, word_embedding, pos_embedding):
        # word dropout, but replace with unk instead of zero-ing embed
        self.replace_with_unk(toks)
        word_embed = F.embedding(
            toks, word_embedding) * self.scale  # [bsz, len, dim]
        lang_embed = self.lang_embedding[lang_idx].unsqueeze(0).unsqueeze(
            1)  # [1, 1, dim]
        pos_embed = pos_embedding[:toks.size(-1), :].unsqueeze(
            0)  # [1, len, dim]

        return word_embed + lang_embed + pos_embed

    def forward(self, src, tgt, targets, src_lang_idx, tgt_lang_idx,
                logit_mask):
        embed_dim = self.args.embed_dim
        max_len = max(src.size(1), tgt.size(1))
        pos_embedding = ut.get_positional_encoding(embed_dim, max_len)
        word_embedding = F.normalize(
            self.word_embedding,
            dim=-1) if self.args.fix_norm else self.word_embedding

        encoder_inputs = self.get_input(src, src_lang_idx, word_embedding,
                                        pos_embedding)
        encoder_mask = (src == ac.PAD_ID).unsqueeze(1).unsqueeze(2)
        encoder_outputs = self.encoder(encoder_inputs, encoder_mask)

        decoder_inputs = self.get_input(tgt, tgt_lang_idx, word_embedding,
                                        pos_embedding)
        decoder_mask = torch.triu(torch.ones((tgt.size(-1), tgt.size(-1))),
                                  diagonal=1).type(tgt.type()) == 1
        decoder_mask = decoder_mask.unsqueeze(0).unsqueeze(1)
        decoder_outputs = self.decoder(decoder_inputs, decoder_mask,
                                       encoder_outputs, encoder_mask)

        logit_mask = logit_mask == 1 if self.logit_mask is None else self.logit_mask
        logits = self.logit_fn(decoder_outputs, word_embedding, logit_mask)
        neglprobs = F.log_softmax(logits, -1) * logit_mask.type(
            logits.type()).reshape(1, -1)
        targets = targets.reshape(-1, 1)
        non_pad_mask = targets != ac.PAD_ID

        nll_loss = neglprobs.gather(dim=-1, index=targets)[non_pad_mask]
        smooth_loss = neglprobs.sum(dim=-1, keepdim=True)[non_pad_mask]

        # label smoothing: https://arxiv.org/pdf/1701.06548.pdf
        nll_loss = -(nll_loss.sum())
        smooth_loss = -(smooth_loss.sum())
        label_smoothing = self.args.label_smoothing
        if label_smoothing > 0:
            loss = (
                1.0 - label_smoothing
            ) * nll_loss + label_smoothing * smooth_loss / logit_mask.type(
                nll_loss.type()).sum()
        else:
            loss = nll_loss

        num_words = non_pad_mask.type(loss.type()).sum()
        opt_loss = loss / num_words
        return {
            'opt_loss': opt_loss,
            'loss': loss,
            'nll_loss': nll_loss,
            'num_words': num_words
        }

    def logit_fn(self, decoder_output, softmax_weight, logit_mask):
        logits = F.linear(decoder_output, softmax_weight, bias=self.out_bias)
        logits = logits.reshape(-1, logits.size(-1))
        logits[:, ~logit_mask] = -1e9
        return logits

    def beam_decode(self, src, src_lang_idx, tgt_lang_idx, logit_mask):
        embed_dim = self.args.embed_dim
        max_len = src.size(1) + 51
        pos_embedding = ut.get_positional_encoding(embed_dim, max_len)
        word_embedding = F.normalize(
            self.word_embedding,
            dim=-1) if self.args.fix_norm else self.word_embedding
        logit_mask = logit_mask == 1 if self.logit_mask is None else self.logit_mask
        tgt_lang_embed = self.lang_embedding[tgt_lang_idx]

        encoder_inputs = self.get_input(src, src_lang_idx, word_embedding,
                                        pos_embedding)
        encoder_mask = (src == ac.PAD_ID).unsqueeze(1).unsqueeze(2)
        encoder_outputs = self.encoder(encoder_inputs, encoder_mask)

        def get_tgt_inp(tgt, time_step):
            word_embed = F.embedding(tgt.type(src.type()),
                                     word_embedding) * self.scale
            pos_embed = pos_embedding[time_step, :].reshape(1, 1, -1)
            return word_embed + tgt_lang_embed + pos_embed

        def logprob_fn(decoder_output):
            logits = self.logit_fn(decoder_output, word_embedding, logit_mask)
            return F.log_softmax(logits, dim=-1)

        # following Attention is all you need, we decode up to src_len + 50 tokens only
        max_lengths = torch.sum(src != ac.PAD_ID, dim=-1).type(src.type()) + 50
        return self.decoder.beam_decode(encoder_outputs,
                                        encoder_mask,
                                        get_tgt_inp,
                                        logprob_fn,
                                        ac.BOS_ID,
                                        ac.EOS_ID,
                                        max_lengths,
                                        beam_size=self.args.beam_size,
                                        alpha=self.args.beam_alpha)
Example #9
def main(num_epochs=10,
         k=100,
         batch_size=128,
         display_freq=100,
         save_freq=1000,
         load_previous=False,
         attention=True,
         word_by_word=True,
         p=0,
         mode='word_by_word'):
    print('num_epochs: {}'.format(num_epochs))
    print('k: {}'.format(k))
    print('batch_size: {}'.format(batch_size))
    print('display_frequency: {}'.format(display_freq))
    print('save_frequency: {}'.format(save_freq))
    print('load previous: {}'.format(load_previous))
    print('attention: {}'.format(attention))
    print('word_by_word: {}'.format(word_by_word))
    save_filename = './snli/{}_model.npz'.format(mode)
    print("Building network ...")
    premise_var = T.imatrix('premise_var')
    premise_mask = T.imatrix('premise_mask')
    hypo_var = T.imatrix('hypo_var')
    hypo_mask = T.imatrix('hypo_mask')
    unchanged_W = pickle.load(open('./snli/unchanged_W.pkl', 'rb'))
    unchanged_W = unchanged_W.astype('float32')
    unchanged_W_shape = unchanged_W.shape
    oov_in_train_W = pickle.load(open('./snli/oov_in_train_W.pkl', 'rb'))
    oov_in_train_W = oov_in_train_W.astype('float32')
    oov_in_train_W_shape = oov_in_train_W.shape
    print('unchanged_W.shape: {0}'.format(unchanged_W_shape))
    print('oov_in_train_W.shape: {0}'.format(oov_in_train_W_shape))
    # hyperparameters
    learning_rate = 0.001
    l2_weight = 0.
    #Input layers
    l_premise = lasagne.layers.InputLayer(shape=(None, premise_max),
                                          input_var=premise_var)
    l_premise_mask = lasagne.layers.InputLayer(shape=(None, premise_max),
                                               input_var=premise_mask)
    l_hypo = lasagne.layers.InputLayer(shape=(None, hypothesis_max),
                                       input_var=hypo_var)
    l_hypo_mask = lasagne.layers.InputLayer(shape=(None, hypothesis_max),
                                            input_var=hypo_mask)
    #Embedded layers
    premise_embedding = EmbeddedLayer(l_premise,
                                      unchanged_W,
                                      unchanged_W_shape,
                                      oov_in_train_W,
                                      oov_in_train_W_shape,
                                      p=p)
    #weights shared with premise_embedding
    hypo_embedding = EmbeddedLayer(
        l_hypo,
        unchanged_W=premise_embedding.unchanged_W,
        unchanged_W_shape=unchanged_W_shape,
        oov_in_train_W=premise_embedding.oov_in_train_W,
        oov_in_train_W_shape=oov_in_train_W_shape,
        p=p,
        dropout_mask=premise_embedding.dropout_mask)
    #Dense layers
    l_premise_linear = DenseLayer(premise_embedding,
                                  k,
                                  nonlinearity=lasagne.nonlinearities.linear)
    l_hypo_linear = DenseLayer(hypo_embedding,
                               k,
                               W=l_premise_linear.W,
                               b=l_premise_linear.b,
                               nonlinearity=lasagne.nonlinearities.linear)

    encoder = Encoder(l_premise_linear,
                      k,
                      peepholes=False,
                      mask_input=l_premise_mask)
    #initialized with encoder final hidden state
    decoder = Decoder(l_hypo_linear,
                      k,
                      cell_init=encoder,
                      peepholes=False,
                      mask_input=l_hypo_mask,
                      encoder_mask_input=l_premise_mask,
                      attention=attention,
                      word_by_word=word_by_word)
    if p > 0.:
        print('apply dropout rate {} to decoder'.format(p))
        decoder = lasagne.layers.DropoutLayer(decoder, p)
    l_softmax = lasagne.layers.DenseLayer(
        decoder, num_units=3, nonlinearity=lasagne.nonlinearities.softmax)
    target_var = T.ivector('target_var')

    #lasagne.layers.get_output produces a variable for the output of the net
    prediction = lasagne.layers.get_output(l_softmax, deterministic=False)
    #The network output will have shape (n_batch, 3);
    loss = lasagne.objectives.categorical_crossentropy(prediction, target_var)
    cost = loss.mean()
    if l2_weight > 0.:
        #apply l2 regularization
        print('apply l2 penalty to all layers, weight: {}'.format(l2_weight))
        regularized_layers = {encoder: l2_weight, decoder: l2_weight}
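        # note: regularized_layers is unused here; regularize_network_params below covers all layers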
        l2_penalty = lasagne.regularization.regularize_network_params(
            l_softmax, lasagne.regularization.l2) * l2_weight
        cost += l2_penalty

    #Retrieve all parameters from the network
    all_params = lasagne.layers.get_all_params(l_softmax, trainable=True)
    #Compute adam updates for training
    print("Computing updates ...")
    updates = lasagne.updates.adam(cost,
                                   all_params,
                                   learning_rate=learning_rate)

    test_prediction = lasagne.layers.get_output(l_softmax, deterministic=True)
    test_loss = lasagne.objectives.categorical_crossentropy(
        test_prediction, target_var)
    test_loss = test_loss.mean()
    # lasagne.objectives.categorical_accuracy()
    # As a bonus, also create an expression for the classification accuracy:
    test_acc = T.mean(T.eq(T.argmax(test_prediction, axis=1), target_var),
                      dtype=theano.config.floatX)

    # Theano functions for training and computing cost
    print("Compiling functions ...")
    train_fn = theano.function(
        [premise_var, premise_mask, hypo_var, hypo_mask, target_var],
        cost,
        updates=updates)
    val_fn = theano.function(
        [premise_var, premise_mask, hypo_var, hypo_mask, target_var],
        [test_loss, test_acc])
    print("Training ...")

    print('train_data.shape: {0}'.format(train_data.shape))
    print('val_data.shape: {0}'.format(val_data.shape))
    print('test_data.shape: {0}'.format(test_data.shape))
    try:
        # Finally, launch the training loop.
        print("Training started...")
        # iterate over epochs:
        for epoch in range(num_epochs):
            # In each epoch, do a full pass over the training data:
            shuffled_train_data = train_data.reindex(
                np.random.permutation(train_data.index))
            train_err = 0
            train_acc = 0
            train_batches = 0
            start_time = time.time()
            display_at = time.time()
            save_at = time.time()
            for start_i in range(0, len(shuffled_train_data), batch_size):
                batched_data = shuffled_train_data[start_i:start_i +
                                                   batch_size]
                ps, p_masks, hs, h_masks, labels = prepare(batched_data)
                train_err += train_fn(ps, p_masks, hs, h_masks, labels)
                err, acc = val_fn(ps, p_masks, hs, h_masks, labels)
                train_acc += acc
                train_batches += 1
                # display
                if train_batches % display_freq == 0:
                    print("Seen {:d} samples, time used: {:.3f}s".format(
                        start_i + batch_size,
                        time.time() - display_at))
                    print("  current training loss:\t\t{:.6f}".format(
                        train_err / train_batches))
                    print("  current training accuracy:\t\t{:.6f}".format(
                        train_acc / train_batches))
                # periodically save the model
                if train_batches % save_freq == 0:
                    print(
                        'saving to ..., time used {:.3f}s'.format(time.time() -
                                                                  save_at))
                    np.savez(save_filename,
                             *lasagne.layers.get_all_param_values(l_softmax))
                    save_at = time.time()

            # And a full pass over the validation data:
            val_err = 0
            val_acc = 0
            val_batches = 0
            for start_i in range(0, len(val_data), batch_size):
                batched_data = val_data[start_i:start_i + batch_size]
                ps, p_masks, hs, h_masks, labels = prepare(batched_data)
                err, acc = val_fn(ps, p_masks, hs, h_masks, labels)
                val_err += err
                val_acc += acc
                val_batches += 1

            # Then we print the results for this epoch:
            print("Epoch {} of {} took {:.3f}s".format(
                epoch + 1, num_epochs,
                time.time() - start_time))
            print("  training loss:\t\t{:.6f}".format(train_err /
                                                      train_batches))
            print("  training accuracy:\t\t{:.2f} %".format(
                train_acc / train_batches * 100))
            print("  validation loss:\t\t{:.6f}".format(val_err / val_batches))
            print("  validation accuracy:\t\t{:.2f} %".format(
                val_acc / val_batches * 100))

            # After training, we compute and print the test error:
            test_err = 0
            test_acc = 0
            test_batches = 0
            for start_i in range(0, len(test_data), batch_size):
                batched_data = test_data[start_i:start_i + batch_size]
                ps, p_masks, hs, h_masks, labels = prepare(batched_data)
                err, acc = val_fn(ps, p_masks, hs, h_masks, labels)
                test_err += err
                test_acc += acc
                test_batches += 1
            # print("Final results:")
            print("  test loss:\t\t\t{:.6f}".format(test_err / test_batches))
            print("  test accuracy:\t\t{:.2f} %".format(test_acc /
                                                        test_batches * 100))
            filename = './snli/{}_model_epoch{}.npz'.format(mode, epoch + 1)
            print('saving to {}'.format(filename))
            np.savez(filename, *lasagne.layers.get_all_param_values(l_softmax))

        # Optionally, you could now dump the network weights to a file like this:
        # np.savez('model.npz', *lasagne.layers.get_all_param_values(network))
        #
        # And load them again later on like this:
        # with np.load('model.npz') as f:
        #     param_values = [f['arr_%d' % i] for i in range(len(f.files))]
        # lasagne.layers.set_all_param_values(network, param_values)
    except KeyboardInterrupt:
        print('exit ...')
Example #10
    def __init__(self, hps, device):
        super(WorkingMemoryModel, self).__init__()
        self.hps = hps
        self.device = device

        self.global_trace_size = hps.global_trace_size
        self.topic_trace_size = hps.topic_trace_size
        self.topic_slots = hps.topic_slots
        self.his_mem_slots = hps.his_mem_slots

        self.vocab_size = hps.vocab_size
        self.mem_size = hps.mem_size

        self.sens_num = hps.sens_num

        self.pad_idx = hps.pad_idx
        self.bos_tensor = torch.tensor(hps.bos_idx, dtype=torch.long, device=device)

        # ----------------------------
        # build components
        self.layers = nn.ModuleDict()
        self.layers['word_embed'] = nn.Embedding(hps.vocab_size,
            hps.word_emb_size, padding_idx=hps.pad_idx)

        # NOTE: We use 33 fixed phonology categories: 0~32
        #   please refer to preprocess.py for more details
        self.layers['ph_embed'] = nn.Embedding(33, hps.ph_emb_size)

        self.layers['len_embed'] = nn.Embedding(hps.sen_len, hps.len_emb_size)


        self.layers['encoder'] = BidirEncoder(hps.word_emb_size, hps.hidden_size, drop_ratio=hps.drop_ratio)
        self.layers['decoder'] = Decoder(hps.hidden_size, hps.hidden_size, drop_ratio=hps.drop_ratio)

        # project the decoder hidden state to a vocabulary-size output logit
        self.layers['out_proj'] = nn.Linear(hps.hidden_size, hps.vocab_size)

        # update the context vector
        self.layers['global_trace_updater'] = ContextLayer(hps.hidden_size, hps.global_trace_size)
        self.layers['topic_trace_updater'] = MLP(self.mem_size+self.topic_trace_size,
            layer_sizes=[self.topic_trace_size], activs=['tanh'], drop_ratio=hps.drop_ratio)


        # MLPs for calculating the initial decoder and key states
        self.layers['dec_init'] = MLP(hps.hidden_size*2, layer_sizes=[hps.hidden_size],
            activs=['tanh'], drop_ratio=hps.drop_ratio)
        self.layers['key_init'] = MLP(hps.hidden_size*2, layer_sizes=[hps.hidden_size],
            activs=['tanh'], drop_ratio=hps.drop_ratio)

        # history memory reading and writing layers
        # query: concatenation of hidden state, global_trace and topic_trace
        self.layers['memory_read'] = AttentionReader(
            d_q=hps.hidden_size+self.global_trace_size+self.topic_trace_size+self.topic_slots,
            d_v=hps.mem_size, drop_ratio=hps.attn_drop_ratio)

        self.layers['memory_write'] = AttentionWriter(hps.mem_size+self.global_trace_size, hps.mem_size)

        # NOTE: a layer to compress the encoder hidden states to a smaller size for a larger number of slots
        self.layers['mem_compress'] = MLP(hps.hidden_size*2, layer_sizes=[hps.mem_size],
            activs=['tanh'], drop_ratio=hps.drop_ratio)

        # [inp, attns, ph_inp, len_inp, global_trace]
        self.layers['merge_x'] = MLP(
            hps.word_emb_size+hps.ph_emb_size+hps.len_emb_size+hps.global_trace_size+hps.mem_size,
            layer_sizes=[hps.hidden_size],
            activs=['tanh'], drop_ratio=hps.drop_ratio)


        # two annealing parameters
        self._tau = 1.0
        self._teach_ratio = 0.8


        # ---------------------------------------------------------
        # only used for pre-training
        self.layers['dec_init_pre'] = MLP(hps.hidden_size*2,
            layer_sizes=[hps.hidden_size],
            activs=['tanh'], drop_ratio=hps.drop_ratio)

        self.layers['merge_x_pre'] = MLP(
            hps.word_emb_size+hps.ph_emb_size+hps.len_emb_size,
            layer_sizes=[hps.hidden_size],
            activs=['tanh'], drop_ratio=hps.drop_ratio)
Example #11
    def __init__(self, args, tokenizer) -> None:
        super().__init__(args, tokenizer)

        self.decoder = Decoder(args, len(tokenizer))
        self.criterion = nn.CrossEntropyLoss()
Example #12
class Seq2Seq(pl.LightningModule):
    def __init__(
            self,
            src_lang,
            trg_lang,
            max_len=32,
            hid_dim=256,
            enc_layers=3,
            dec_layers=3,
            enc_heads=8,
            dec_heads=8,
            enc_pf_dim=512,
            dec_pf_dim=512,
            enc_dropout=0.1,
            dec_dropout=0.1,
            lr=0.0005,
            **kwargs,  # throwaway
    ):
        super().__init__()

        self.save_hyperparameters()
        del self.hparams["src_lang"]
        del self.hparams["trg_lang"]

        self.src_lang = src_lang
        self.trg_lang = trg_lang

        self.encoder = Encoder(
            src_lang.n_words,
            hid_dim,
            enc_layers,
            enc_heads,
            enc_pf_dim,
            enc_dropout,
            device,
        )

        self.decoder = Decoder(
            trg_lang.n_words,
            hid_dim,
            dec_layers,
            dec_heads,
            dec_pf_dim,
            dec_dropout,
            device,
        )

        self.criterion = nn.CrossEntropyLoss(
            ignore_index=self.trg_lang.PAD_idx)
        self.initialize_weights()
        self.to(device)

    def initialize_weights(self):
        def _initialize_weights(m):
            if hasattr(m, "weight") and m.weight.dim() > 1:
                nn.init.xavier_uniform_(m.weight.data)

        self.encoder.apply(_initialize_weights)
        self.decoder.apply(_initialize_weights)

    def make_src_mask(self, src):

        # src = [batch size, src len]

        src_mask = (src != self.src_lang.PAD_idx).unsqueeze(1).unsqueeze(2)

        # src_mask = [batch size, 1, 1, src len]

        return src_mask

    def make_trg_mask(self, trg):

        # trg = [batch size, trg len]

        trg_pad_mask = (trg != self.trg_lang.PAD_idx).unsqueeze(1).unsqueeze(2)

        # trg_pad_mask = [batch size, 1, 1, trg len]

        trg_len = trg.shape[1]

        trg_sub_mask = torch.tril(torch.ones(
            (trg_len, trg_len)).type_as(trg)).bool()

        # trg_sub_mask = [trg len, trg len]

        trg_mask = trg_pad_mask & trg_sub_mask

        # trg_mask = [batch size, 1, trg len, trg len]

        return trg_mask

    def forward(self, src, trg):

        # src = [batch size, src len]
        # trg = [batch size, trg len]

        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)

        # src_mask = [batch size, 1, 1, src len]
        # trg_mask = [batch size, 1, trg len, trg len]

        enc_src = self.encoder(src, src_mask)

        # enc_src = [batch size, src len, hid dim]

        output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)

        # output = [batch size, trg len, output dim]
        # attention = [batch size, n heads, trg len, src len]

        return output, attention

    def predict(self, sentences, batch_size=128):
        """Efficiently predict a list of sentences"""
        pred_tensors = [
            sentence_to_tensor(sentence, self.src_lang)
            for sentence in tqdm(sentences, desc="creating prediction tensors")
        ]

        collate_fn = Collater(self.src_lang, predict=True)
        pred_dataloader = DataLoader(
            SimpleDataset(pred_tensors),
            batch_size=batch_size,
            collate_fn=collate_fn,
        )

        sentences = []
        words = []
        attention = []
        for batch in tqdm(pred_dataloader, desc="predict batch num"):
            preds = self.predict_batch(batch.to(device))
            pred_sentences, pred_words, pred_attention = preds
            sentences.extend(pred_sentences)
            words.extend(pred_words)
            attention.extend(pred_attention)

        # sentences = [num pred sentences]
        # words = [num pred sentences, trg len]
        # attention = [num pred sentences, n heads, trg len, src len]

        return sentences, words, attention
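        # Hypothetical usage (class and variable names are illustrative;
        # `load_from_checkpoint` is the standard pytorch-lightning loader):
        #   model = TransformerModel.load_from_checkpoint(
        #       "best.ckpt", src_lang=src_lang, trg_lang=trg_lang)
        #   sentences, words, attention = model.predict(["some source text"])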

    def predict_batch(self, batch):
        """Predicts on a batch of src_tensors."""
        # batch = src_tensor when predicting = [batch_size, src len]

        src_tensor = batch
        src_mask = self.make_src_mask(batch)

        # src_mask = [batch size, 1, 1, src len]

        enc_src = self.encoder(src_tensor, src_mask)

        # enc_src = [batch size, src len, hid dim]

        trg_indexes = [[self.trg_lang.SOS_idx] for _ in range(len(batch))]

        # trg_indexes = [batch_size, cur trg len = 1]

        trg_tensor = torch.LongTensor(trg_indexes).to(self.device)

        # trg_tensor = [batch_size, cur trg len = 1]
        # cur trg len increases during the for loop up to the max len

        for _ in range(self.hparams.max_len):

            trg_mask = self.make_trg_mask(trg_tensor)

            # trg_mask = [batch size, 1, cur trg len, cur trg len]

            output, attention = self.decoder(trg_tensor, enc_src, trg_mask,
                                             src_mask)

            # output = [batch size, cur trg len, output dim]

            preds = output.argmax(2)[:, -1].reshape(-1, 1)

            # preds = [batch_size, 1]

            trg_tensor = torch.cat((trg_tensor, preds), dim=-1)

            # trg_tensor = [batch_size, cur trg len], cur trg len increased by 1

        src_tensor = src_tensor.detach().cpu().numpy()
        trg_tensor = trg_tensor.detach().cpu().numpy()
        attention = attention.detach().cpu().numpy()

        pred_words = []
        pred_sentences = []
        pred_attention = []
        for src_indexes, trg_indexes, attn in zip(src_tensor, trg_tensor,
                                                  attention):
            # trg_indexes = [max len] (decoding always runs max_len steps;
            #               anything after the first eos is discarded below)
            # src_indexes = [src len = longest sentence in batch (padded otherwise)]

            # indexes where first eos tokens appear
            src_eosi = np.where(src_indexes == self.src_lang.EOS_idx)[0][0]
            _trg_eosi_arr = np.where(trg_indexes == self.trg_lang.EOS_idx)[0]
            if len(_trg_eosi_arr) > 0:  # check that an eos token exists in trg
                trg_eosi = _trg_eosi_arr[0]
            else:
                trg_eosi = len(trg_indexes)

            # cut target indexes up to first eos token and also exclude sos token
            trg_indexes = trg_indexes[1:trg_eosi]

            # attn = [n heads, trg len=max len, src len=max len of sentence in batch]
            # we want to keep n heads, but we'll cut trg len and src len up to
            # their first eos token
            attn = attn[:, :trg_eosi, :src_eosi]  # drop rows/cols past eos

            words = [self.trg_lang.index2word[index] for index in trg_indexes]
            sentence = self.trg_lang.words_to_sentence(words)
            pred_words.append(words)
            pred_sentences.append(sentence)
            pred_attention.append(attn)

        # pred_sentences = [batch_size]
        # pred_words = [batch_size, trg len]
        # pred_attention = [batch size, n heads, trg len (varies), src len (varies)]

        return pred_sentences, pred_words, pred_attention
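        # Greedy decoding above always runs for hparams.max_len steps. A
        # common optimization (sketched here, not implemented above) is to
        # break out of the decode loop once every row has produced an eos:
        #   finished = (trg_tensor == self.trg_lang.EOS_idx).any(dim=1)
        #   if finished.all():
        #       break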

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

    def training_step(self, batch, batch_idx):
        src, trg = batch

        output, _ = self(src, trg[:, :-1])

        # output = [batch size, trg len - 1, output dim]
        # trg = [batch size, trg len]
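        # Teacher forcing with a one-token shift: the decoder consumes
        # trg[:, :-1] (SOS, w1, ..., w_{n-1}) and the loss below compares
        # against trg[:, 1:] (w1, ..., EOS), so each position is trained to
        # predict the next token.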

        output_dim = output.shape[-1]

        output = output.contiguous().view(-1, output_dim)
        trg = trg[:, 1:].contiguous().view(-1)

        # output = [batch size * trg len - 1, output dim]
        # trg = [batch size * trg len - 1]

        loss = self.criterion(output, trg)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        src, trg = batch

        output, _ = self(src, trg[:, :-1])

        # output = [batch size, trg len - 1, output dim]
        # trg = [batch size, trg len]

        output_dim = output.shape[-1]

        output = output.contiguous().view(-1, output_dim)
        trg = trg[:, 1:].contiguous().view(-1)

        # output = [batch size * trg len - 1, output dim]
        # trg = [batch size * trg len - 1]

        loss = self.criterion(output, trg)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    @staticmethod
    def add_model_specific_args(parent_parser):
        _parser = argparse.ArgumentParser(parents=[parent_parser],
                                          add_help=False)
        _parser.add_argument("--max_len", type=int, default=32)
        _parser.add_argument("--hid_dim", type=int, default=256)
        _parser.add_argument("--enc_layers", type=int, default=3)
        _parser.add_argument("--dec_layers", type=int, default=3)
        _parser.add_argument("--enc_heads", type=int, default=8)
        _parser.add_argument("--dec_heads", type=int, default=8)
        _parser.add_argument("--enc_pf_dim", type=int, default=512)
        _parser.add_argument("--dec_pf_dim", type=int, default=512)
        _parser.add_argument("--enc_dropout", type=float, default=0.1)
        _parser.add_argument("--dec_dropout", type=float, default=0.1)
        _parser.add_argument("--lr", type=float, default=0.0005)
        return _parser
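
# A sketch of how the argument helper above is typically wired into an entry
# point (class and variable names here are illustrative, not from this file):
#   parent = argparse.ArgumentParser(add_help=False)
#   parent.add_argument("--batch_size", type=int, default=64)
#   parser = TransformerModel.add_model_specific_args(parent)
#   args = parser.parse_args()
#   model = TransformerModel(src_lang, trg_lang, **vars(args))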


def main():
    parser = argparse.ArgumentParser()

    default_dataset = 'toy-data.npz'
    parser.add_argument('--data', default=default_dataset,
                        help='data file')
    parser.add_argument('--seed', type=int, default=None,
                        help='random seed. Randomly set if not specified.')

    # training options
    parser.add_argument('--nz', type=int, default=32,
                        help='dimension of latent variable')
    parser.add_argument('--epoch', type=int, default=1000,
                        help='number of training epochs')
    parser.add_argument('--batch-size', type=int, default=128,
                        help='batch size')
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='learning rate')
    parser.add_argument('--min-lr', type=float, default=5e-5,
                        help='min learning rate for LR scheduler. '
                             '-1 to disable annealing')
    parser.add_argument('--plot-interval', type=int, default=10,
                        help='plot interval. 0 to disable plotting.')
    parser.add_argument('--save-interval', type=int, default=0,
                        help='interval to save models. 0 to disable saving.')
    parser.add_argument('--prefix', default='pvae',
                        help='prefix of output directory')
    parser.add_argument('--comp', type=int, default=5,
                        help='continuous convolution kernel size')
    parser.add_argument('--sigma', type=float, default=.2,
                        help='standard deviation for Gaussian likelihood')
    parser.add_argument('--overlap', type=float, default=.5,
                        help='kernel overlap')
    # note: --squash only takes effect when --rescale is enabled (see below)
    parser.add_argument('--squash', dest='squash', action='store_const',
                        const=True, default=True,
                        help='bound the generated time series value '
                             'using tanh')
    parser.add_argument('--no-squash', dest='squash', action='store_const',
                        const=False)

    # rescale to [-1, 1]
    parser.add_argument('--rescale', dest='rescale', action='store_const',
                        const=True, default=True,
                        help='if set, rescale time to [-1, 1]')
    parser.add_argument('--no-rescale', dest='rescale', action='store_const',
                        const=False)
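
    # Since Python 3.9, each const-pair above could be a single argument via
    # argparse.BooleanOptionalAction, e.g.:
    #   parser.add_argument('--squash', action=argparse.BooleanOptionalAction,
    #                       default=True, help='bound outputs with tanh')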

    args = parser.parse_args()

    batch_size = args.batch_size
    nz = args.nz

    epochs = args.epoch
    plot_interval = args.plot_interval
    save_interval = args.save_interval

    try:
        npz = np.load(args.data)
        train_data = npz['data']
        train_time = npz['time']
        train_mask = npz['mask']
    except FileNotFoundError:
        if args.data != default_dataset:
            raise
        # Generate the default toy dataset from scratch
        train_data, train_time, train_mask, _, _ = gen_data(
            n_samples=10000, seq_len=200, max_time=1, poisson_rate=50,
            obs_span_rate=.25, save_file=default_dataset)

    _, in_channels, seq_len = train_data.shape
    train_time *= train_mask

    if args.seed is None:
        rnd = np.random.RandomState(None)
        random_seed = rnd.randint(np.iinfo(np.uint32).max)
    else:
        random_seed = args.seed
    rnd = np.random.RandomState(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
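    # For end-to-end GPU reproducibility one would also seed CUDA and pin
    # cuDNN (standard torch calls, not in the original script):
    #   torch.cuda.manual_seed_all(random_seed)
    #   torch.backends.cudnn.deterministic = True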

    # Scale time
    max_time = 5
    train_time *= max_time

    squash = None
    rescaler = None
    if args.rescale:
        rescaler = Rescaler(train_data)
        train_data = rescaler.rescale(train_data)
        if args.squash:
            squash = torch.tanh

    out_channels = 64
    cconv_ref = 98

    train_dataset = TimeSeries(
        train_data, train_time, train_mask, label=None, max_time=max_time,
        cconv_ref=cconv_ref, overlap_rate=args.overlap, device=device)

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        drop_last=True, collate_fn=train_dataset.collate_fn)
    n_train_batch = len(train_loader)

    test_batch_size = 64
    test_loader = DataLoader(train_dataset, batch_size=test_batch_size,
                             collate_fn=train_dataset.collate_fn)

    grid_decoder = SeqGeneratorDiscrete(in_channels, nz, squash)
    decoder = Decoder(grid_decoder, max_time=max_time).to(device)

    cconv = ContinuousConv1D(
        in_channels, out_channels, max_time, cconv_ref,
        overlap_rate=args.overlap, kernel_size=args.comp, norm=True).to(device)

    encoder = Encoder(nz, cconv).to(device)

    pvae = PVAE(encoder, decoder, sigma=args.sigma).to(device)

    optimizer = optim.Adam(pvae.parameters(), lr=args.lr)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)
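    # make_scheduler is defined elsewhere in this codebase; given the
    # (optimizer, lr, min_lr, epochs) signature and the "-1 to disable
    # annealing" convention above, a plausible sketch is cosine annealing:
    #   if args.min_lr < 0:
    #       scheduler = None
    #   else:
    #       scheduler = optim.lr_scheduler.CosineAnnealingLR(
    #           optimizer, T_max=epochs, eta_min=args.min_lr)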

    path = '{}_{}_{}'.format(
        args.prefix, datetime.now().strftime('%m%d.%H%M%S'),
        '_'.join([f'lr_{args.lr:g}']))

    output_dir = Path('results') / 'toy-pvae' / path
    print(output_dir)
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    start_epoch = 0

    with (log_dir / 'seed.txt').open('w') as f:
        print(random_seed, file=f)
    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)
    with (log_dir / 'args.txt').open('w') as f:
        for key, val in sorted(vars(args).items()):
            print(f'{key}: {val}', file=f)

    tracker = Tracker(log_dir, n_train_batch)
    visualizer = Visualizer(encoder, decoder, test_batch_size, max_time,
                            test_loader, rescaler, output_dir, device)
    start = time.time()
    epoch_start = start

    for epoch in range(start_epoch, epochs):
        loss_breakdown = defaultdict(float)
        for val, idx, mask, _, cconv_graph in train_loader:
            optimizer.zero_grad()
            loss = pvae(val, idx, mask, cconv_graph)
            loss.backward()
            optimizer.step()
            loss_breakdown['loss'] += loss.item()

        if scheduler:
            scheduler.step()

        cur_time = time.time()
        tracker.log(
            epoch, loss_breakdown, cur_time - epoch_start, cur_time - start)

        if plot_interval > 0 and (epoch + 1) % plot_interval == 0:
            visualizer.plot(epoch)

        model_dict = {
            'pvae': pvae.state_dict(),
            'epoch': epoch + 1,
            'args': args,
        }
        torch.save(model_dict, str(log_dir / 'model.pth'))
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
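

# The dict written to model.pth above bundles weights, the next epoch index
# and the run's args. A minimal resume helper along these lines
# (hypothetical, not part of the original script) could restore it:
def load_checkpoint(path, pvae, device):
    ckpt = torch.load(path, map_location=device)
    pvae.load_state_dict(ckpt['pvae'])
    return ckpt['epoch'], ckpt['args']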