Example #1
0
    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super().__init__()
        nx = config.n_embd
        self.c_fc = Conv1D(n_state, nx)
        self.c_proj = Conv1D(nx, n_state)
        self.act = gelu
        self.dropout = nn.Dropout(config.resid_pdrop)

    def __init__(self, n_state, n_embd):  # in MLP: n_state=3072 (4 * n_embd)
        super().__init__()
        nx = n_embd
        self.c_fc = Conv1D(n_state, nx)
        self.c_proj = Conv1D(nx, n_state)
        self.act = gelu
        self.dropout = nn.Dropout(0)

    def __init__(self, nx, n_ctx, n_head, scale=False):
        super().__init__()
        self.output_attentions = False

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % n_head == 0
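        # Causal mask: lower-triangular (n_ctx, n_ctx) buffer of ones; position i may only
        # attend to positions <= i.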
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = n_head
        self.split_size = n_state
        self.scale = scale

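        # c_attn projects the input to query, key and value in a single matmul (hence the
        # 3 * n_state output width); c_proj maps the attention output back to n_state.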
        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(0)
        self.resid_dropout = nn.Dropout(0)
        self.pruned_heads = set()
Example #4
0
    def __init__(self, nx, n_ctx, config, scale=False):
        super().__init__()
        self.output_attentions = config.output_attentions

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        self.register_buffer(
            "bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
        )
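        # Large negative value substituted for masked (future) positions in the attention
        # scores before the softmax.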
        self.register_buffer("masked_bias", torch.tensor(-1e4))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.pruned_heads = set()
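
    # Sketch (not part of the original snippet): how the `bias` and `masked_bias` buffers
    # registered above are typically consumed in the attention forward pass, following the
    # GPT-2 _attn logic. Assumptions: q, k, v have shape (batch, n_head, seq_len, head_dim),
    # with k already transposed so that torch.matmul(q, k) yields the score matrix.
    def _attn_sketch(self, q, k, v):
        w = torch.matmul(q, k)                      # raw attention scores
        if self.scale:
            w = w / (float(v.size(-1)) ** 0.5)      # scale by sqrt(head_dim)
        nd, ns = w.size(-2), w.size(-1)
        causal_mask = self.bias[:, :, ns - nd:ns, :ns]
        w = torch.where(causal_mask.bool(), w, self.masked_bias.to(w.dtype))
        w = nn.Softmax(dim=-1)(w)
        w = self.attn_dropout(w)
        return torch.matmul(w, v)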
Example #5
0
def prune_conv1d_layer(layer, index, dim=1):
    """ Prune a Conv1D layer (a model parameters) to keep only entries in index.
        A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed.
        Return the pruned layer as a new layer with requires_grad=True.
        Used to remove heads.
    """
    index = index.to(layer.weight.device)
    W = layer.weight.index_select(dim, index).clone().detach()
    if dim == 0:
        b = layer.bias.clone().detach()
    else:
        b = layer.bias[index].clone().detach()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
    new_layer.weight.requires_grad = False
    new_layer.weight.copy_(W.contiguous())
    new_layer.weight.requires_grad = True
    new_layer.bias.requires_grad = False
    new_layer.bias.copy_(b.contiguous())
    new_layer.bias.requires_grad = True
    return new_layer
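
# Sketch (assumption, not part of the original function): typical use of prune_conv1d_layer
# to drop attention heads from a hypothetical GPT-2 Attention module `attn` with 12 heads of
# size 64 (n_state = 768). c_attn packs Q, K and V along dim=1, so the kept indices are
# repeated for each of the three chunks; c_proj is pruned along its input rows instead.
def prune_attention_heads_sketch(attn, heads_to_prune, n_head=12, head_dim=64):
    mask = torch.ones(n_head, head_dim)
    for head in heads_to_prune:
        mask[head] = 0
    index = torch.arange(n_head * head_dim)[mask.view(-1).bool()]
    index_attn = torch.cat([index, index + n_head * head_dim, index + 2 * n_head * head_dim])
    attn.c_attn = prune_conv1d_layer(attn.c_attn, index_attn, dim=1)  # prune Q/K/V columns
    attn.c_proj = prune_conv1d_layer(attn.c_proj, index, dim=0)       # prune matching input rows
    return attn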
Example #6
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', type=str, default='data')
    parser.add_argument('--out-dir', type=str, default='out')
    parser.add_argument('--experiment', type=str, default='train')

    # Default parameters are set based on single GPU training
    parser.add_argument('--lr', type=float, default=5e-5)
    parser.add_argument("--seed", type=int, default=0)

    parser.add_argument('--modalities',
                        type=str,
                        default='g2t',
                        choices=['t2t', 'g2g', 'g2t', 't2g'])
    parser.add_argument('--data_type',
                        type=str,
                        default='t1',
                        choices=['t' + str(i) for i in range(9)],
                        help="t: type")
    parser.add_argument('--model_type',
                        type=str,
                        default='mmvae',
                        choices=['vae', 'mmvae'])
    parser.add_argument('--iterations', type=int, default=30_950)
    parser.add_argument('--dataset',
                        type=str,
                        default='g2t',
                        help="Dataset to use for training")
    parser.add_argument(
        '--tuning_all',
        type=int,
        default=5000,
        help='When to start tuning the pretrained model as well')
    parser.add_argument(
        '--warmup',
        type=int,
        default=10000,
        help=
        "Amount of iterations to warmup, then decay. (-1 for no warmup and decay)"
    )

    parser.add_argument('--batch-sizes',
                        nargs='+',
                        type=int,
                        default=[8],
                        help='batch size per GPU. Lists the schedule.')
    parser.add_argument('--seq-lens',
                        nargs='+',
                        type=int,
                        default=[128],
                        help='seq length per sample. Lists the schedule.')
    parser.add_argument(
        '--switch-time',
        type=float,
        default=0,
        help="Percentage of iterations to spend on short sequence training.")
    parser.add_argument(
        '--load', type=str,
        help='path to load model from')  # , default='out/test/'
    parser.add_argument('--workers',
                        default=1,
                        type=int,
                        metavar='N',
                        help='number of data loading workers')

    parser.add_argument('--print_every', type=int, default=1000)
    parser.add_argument('--val_every', type=int, default=5000)

    # Cyclical KL cost annealing: ramp beta from 0 to 1 over the first beta_ratio fraction of each of beta_cycles cycles
    parser.add_argument('--beta_cycles', type=int, default=1)
    parser.add_argument('--beta_ratio', type=float, default=0)
    parser.add_argument('--free_bits', type=float, default=0)
    parser.add_argument('--mu_gamma', type=int, default=0)
    parser.add_argument('--word_dropout', type=float, default=1)

    parser.add_argument('--add_input', action="store_true")
    parser.add_argument('--add_attn', action="store_true")
    parser.add_argument('--add_softmax', action="store_true")
    parser.add_argument('--attn_proj_vary', action="store_true")

    # use GPU
    parser.add_argument('--gpu', default=0, type=int)
    parser.add_argument('--no_gpu', action="store_true")

    args = parser.parse_args()  # wi.12.proj_vary_beta_cvae
    print(args)

    if args.model_type == 'mmvae':
        assert args.modalities == 'g2t', 'Modality has to be g2t when using MMVAE'

    # GPU
    if not torch.cuda.is_available(): args.no_gpu = True
    gpu = not args.no_gpu
    if gpu:
        print("There are ", torch.cuda.device_count(), " available GPUs!")
        # print('Setting GPUs {}'.format(args.device))
        print('Using GPU device {}'.format(args.gpu))
        torch.cuda.set_device(args.gpu)
        print('Current single GPU: {}'.format(torch.cuda.current_device()))
    device = torch.device(args.gpu if gpu else "cpu")

    # randomness
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    if gpu:
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    # logging
    save_folder = os.path.join(args.out_dir, args.experiment)
    os.makedirs(save_folder, exist_ok=True)
    t_writer = SummaryWriter(os.path.join(save_folder, 'train'), flush_secs=5)
    v_writer = SummaryWriter(os.path.join(save_folder, 'val'), flush_secs=5)
    importlib.reload(logging)
    logging.basicConfig(filename=os.path.join(save_folder, 'train.log'),
                        level=logging.INFO,
                        format='%(asctime)s--- %(message)s')
    logging.info(
        '\n*******************************************************************************\n'
    )
    logging.info("the configuration:")
    logging.info(str(args).replace(',', '\n'))

    print('Loading models...')
    cache_dir = os.path.join(args.out_dir, 'model_cache')
    os.makedirs(cache_dir, exist_ok=True)
    # Load pre-trained teacher tokenizer (vocabulary)
    tokenizer = GPT2Tokenizer.from_pretrained(
        'gpt2',
        cache_dir=cache_dir,
        bos_token='<|startoftext|>',
        eos_token='<|endoftext|>',
        pad_token='<pad>',
        cls_token='<cls>',
    )

    # Hack to allow tokenizing longer sequences.
    tokenizer.max_len = int(1e12)
    gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir=cache_dir)
    print('gpt2_params:', num_params(gpt2_model))  # gpt2: 124439808
    config = GPT2Config()

    if 'g' in args.modalities:
        tokenizer.add_tokens(['<S>', '</S>', '<R>', '</R>', '<O>', '</O>'])
        gpt2_model.resize_token_embeddings(len(tokenizer))

    if args.model_type == 'vae':
        model = VAE(config,
                    add_input=args.add_input,
                    add_attn=args.add_attn,
                    add_softmax=args.add_softmax,
                    attn_proj_vary=args.attn_proj_vary)
    elif args.model_type == 'mmvae':
        model = MMVAE(config,
                      add_input=args.add_input,
                      add_attn=args.add_attn,
                      add_softmax=args.add_softmax,
                      attn_proj_vary=args.attn_proj_vary)

    init_para_frompretrained(model.decoder,
                             gpt2_model.transformer,
                             share_para=True)
    init_para_frompretrained(model.encoder,
                             gpt2_model.transformer,
                             share_para=False)
    model.lm_head.weight = gpt2_model.lm_head.weight
    if model.add_softmax:
        model.lm_head_rep = Conv1D(*gpt2_model.lm_head.weight.size())

    print('VAE_params:', num_params(model))  # 286694400
    if args.load:
        print('Loading model weights...')
        state = torch.load(os.path.join(
            args.load,
            'model_latest.pt'))  # optionally pass map_location='cpu'
        if 'module' in list(state.keys(
        ))[0]:  # model_path is data parallel model with attr 'module'
            state_copy = copy.copy(state)
            keys = state_copy.keys()
            for k in keys:
                state[k.replace('module.', '')] = state.pop(k)
        model.load_state_dict(state)
        gc.collect()
    print('Done.')

    # fix pre-trained parameters before certain iterations
    tuning_all_after_iters = args.tuning_all
    tuning_all = False
    for name, parameter in model.named_parameters():
        # print((name, parameter.requires_grad))
        new_pars = [
            'c_z', 'attention_weights', 'mean', 'logvar', 'input_proj',
            'attn_proj', 'Nu_fc1', 'Nu_fc2', 'lm_head_rep'
        ]

        if not any(n in name for n in new_pars):
            parameter.requires_grad = False

    print('Setup data...')
    # Batch and sequence length schedule
    # assert len(args.batch_sizes) == len(args.seq_lens)
    batch_schedule = list(
        zip(map(int, args.batch_sizes), map(int, args.seq_lens)))
    # assert len(batch_schedule) <= 2, 'Currently not supporting multiple schedule'
    cur_b_schedule = len(batch_schedule) - 1 if args.switch_time == 0 else 0
    print('Batch schedule', batch_schedule)
    train_loader, val_loader, test_loader = prepare_dataset(
        args.data_dir,
        args.dataset,
        tokenizer,
        batch_schedule[cur_b_schedule][0],
        batch_schedule[cur_b_schedule][1],
        batch_schedule[-1][0],
        batch_schedule[-1][1],
        batch_schedule[-1][0],
        batch_schedule[-1][1],
        make_test=True,
        modalities=args.modalities,
        num_workers=args.workers,
        data_type=args.data_type)
    print('Done.')

    print('Wrapping models and optimizers...')
    # Apply linear scaling rule to increase batch size for short sequence training.
    lr_schedule = switch_schedule(
        linear_schedule(args),
        batch_schedule[cur_b_schedule][0] / batch_schedule[-1][0],
        int(args.iterations * args.switch_time))
    model = model.to(device)
    model.train()

    optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.9)

    loss_fn = nn.CrossEntropyLoss(reduction='none')
    print('Done.')

    print('Begin training iterations')
    logging.info("Begin training iterations")
    logging.info("Total iteration: %d" % args.iterations)
    e = 0  # epoch counter
    num_iters = 0
    optimizer.zero_grad()
    printing = False
    prints = 0

    # val_or_test(device, VAE, tokenizer, val_loader, loss_fn, v_writer, args,
    #            eos_token='<|endoftext|>', num_iters=num_iters, mode='Validation')
    # torch.save(VAE.state_dict(), os.path.join(save_folder, 'model_' + '{:07d}'.format(num_iters) + '.pt'))

    while num_iters < args.iterations:
        # Run epoch
        st = time.time()

        # Training
        print('Training loop. Batches:', len(train_loader))
        logging.info(
            '\n----------------------------------------------------------------------'
        )
        logging.info("Training loop.       Batches: %d" % len(train_loader))

        with tqdm(total=len(train_loader)) as pbar:
            for i, (x_mask, x_tokens, y_mask, y_tokens, input_tokens,
                    target_tokens, mask) in enumerate(train_loader):

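                # beta schedule: linear ramp from 0 to 1 over the first beta_ratio fraction
                # of each cycle, then held near 1 until the cycle restarts; beta_ratio == 0
                # disables annealing and keeps beta fixed at 1.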
                if args.beta_ratio != 0:
                    if num_iters % (args.iterations / args.beta_cycles) <= (
                            args.iterations /
                            args.beta_cycles) * args.beta_ratio:
                        beta = (1 / ((args.iterations/args.beta_cycles) * args.beta_ratio)) \
                               * (num_iters % (args.iterations/args.beta_cycles))
                else:
                    beta = 1

                if not tuning_all and num_iters >= tuning_all_after_iters:
                    for name, parameter in model.named_parameters():
                        parameter.requires_grad = True
                    tuning_all = True

                output, sentence = train_step(device, model, optimizer, x_mask,
                                              x_tokens, y_mask, y_tokens,
                                              input_tokens, target_tokens,
                                              mask, loss_fn, beta,
                                              args.free_bits, args.mu_gamma,
                                              args.word_dropout)
                loss, ce_loss, kl_loss = output[-1]

                lr = scheduler.get_last_lr()[0]
                # Log to Tensorboard
                t_writer.add_scalar('loss', loss, num_iters)
                t_writer.add_scalar('ppl', math.exp(min(ce_loss, 10)),
                                    num_iters)
                t_writer.add_scalar('lr', lr, num_iters)
                t_writer.add_scalar('iter_time', time.time() - st, num_iters)
                t_writer.add_scalar('kl', kl_loss, num_iters)
                t_writer.add_scalar('beta', beta, num_iters)

                st = time.time()
                end = num_iters >= args.iterations

                if args.warmup != -1:
                    scheduler.step()

                if end: break
                num_iters += 1
                pbar.update(1)

                if num_iters % args.print_every == 0 or printing:
                    print(f'num_iters: {num_iters}')
                    print(f'Loss: {loss}, Beta: {beta}')
                    for j in range(x_tokens.shape[0]):
                        print(f'Target: {tokenizer.decode(x_tokens[j])}')
                        print(
                            f'Reconstruction: {tokenizer.decode(sentence[j])}')
                        print('')
                    prints += 1
                    printing = True
                    if prints == 3:
                        prints = 0
                        printing = False

                if num_iters % args.val_every == 0:
                    val_or_test(device,
                                model,
                                tokenizer,
                                val_loader,
                                loss_fn,
                                v_writer,
                                args,
                                eos_token='<|endoftext|>',
                                num_iters=num_iters)
                    print('Saving model...')
                    logging.info("Iteration completed: %d, remained %d" %
                                 (num_iters, args.iterations - num_iters))
                    logging.info("Saving model...")
                    logging.info(
                        '\n------------------------------------------------------'
                    )
                    torch.save(
                        model.state_dict(),
                        os.path.join(
                            save_folder,
                            'model_' + '{:07d}'.format(num_iters) + '.pt'))

                if args.switch_time > 0 and num_iters == int(
                        args.iterations * args.switch_time):
                    print('Switch to long sequence training')
                    logging.info("Switch to long sequence training")
                    cur_b_schedule += 1
                    train_loader, val_loader, test_loader = prepare_dataset(
                        args.data_dir,
                        args.dataset,
                        tokenizer,
                        batch_schedule[cur_b_schedule][0],
                        batch_schedule[cur_b_schedule][1],
                        batch_schedule[-1][0],
                        batch_schedule[-1][1],
                        batch_schedule[-1][0],
                        batch_schedule[-1][1],
                        make_test=True,
                        num_workers=args.workers,
                        data_type=args.data_type)
        if not end:
            e += 1
            logging.info("Training loop. The ith epoch completed: %d" % e)

    torch.save(model.state_dict(), os.path.join(save_folder,
                                                'model_latest.pt'))
    print('Training complete.')
    logging.info("Training complete.")
Example #7
0
print('We have added', num_added_toks, 'special tokens')
# Notice: resize_token_embeddings expects to receive the full size of the new vocab
gpt2_model.resize_token_embeddings(len(tokenizer))
assert tokenizer.pad_token == '<|startoftext|>'

VAE = VAEModel(config,
               add_input=add_input,
               add_attn=add_attn,
               add_softmax=add_softmax)
init_para_frompretrained(VAE.transformer,
                         gpt2_model.transformer,
                         share_para=True)
init_para_frompretrained(VAE.encoder, gpt2_model.transformer, share_para=False)
VAE.lm_head.weight = gpt2_model.lm_head.weight
if VAE.add_softmax:
    VAE.lm_head_rep = Conv1D(*gpt2_model.lm_head.weight.size())
print('VAE_params:', num_params(VAE))  # 286694400
print('Done.')

print('Loading model weights...')
state = torch.load(os.path.join(args.load, 'model_latest.pt'),
                   map_location='cpu')
if 'module' in list(state.keys(
))[0]:  # model_path is data parallel model with attr 'module'
    state_copy = copy.copy(state)
    keys = state_copy.keys()
    for k in keys:
        state[k.replace('module.', '')] = state.pop(k)
VAE.load_state_dict(state)
VAE.eval()
VAE = VAE.to(device)
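
# Sketch (assumption; the real helper is defined elsewhere in this project): the kind of
# initialization init_para_frompretrained is expected to perform. share_para=True ties the
# pretrained GPT-2 Parameter objects directly into the target module, while share_para=False
# copies their values so the two weight sets can later diverge during fine-tuning.
def init_para_frompretrained_sketch(target, pretrained, share_para=False):
    target_params = dict(target.named_parameters())
    for name, src_param in pretrained.named_parameters():
        if name not in target_params:
            continue  # target-only parameters (e.g. latent projections) keep their own init
        if share_para:
            # Tie: make the target sub-module reference the same Parameter object.
            *path, leaf = name.split('.')
            sub = target
            for attr in path:
                sub = getattr(sub, attr)
            setattr(sub, leaf, src_param)
        else:
            target_params[name].data.copy_(src_param.data)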
Example #8
0
def main_worker(gpu, ngpus_per_node, args):
    if args.model_type == 'cvae':
        args.learn_prior = True
    else:
        args.learn_prior = False

    # GPU
    args.gpu = gpu
    print("There are ", torch.cuda.device_count(), " available GPUs!")
    # print('Setting GPUs {}'.format(args.device))
    print('Using GPU device {}'.format(args.gpu))
    device = torch.device('cuda', args.gpu)
    torch.cuda.set_device(device)
    print('Current single GPU: {}'.format(torch.cuda.current_device()))

    # randomness
    np.random.seed(args.seed)
    prng = np.random.RandomState()
    torch.random.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # For multiprocessing distributed training, rank needs to be the global rank among all the processes
    args.rank = args.rank * ngpus_per_node + gpu
    print('Setting rank', args.rank)
    recon_attempt = 1
    connected = False
    if args.rank != 0:
        # Stall to have rank 0 node go first
        time.sleep(3)
    while not connected:
        try:
            dist.init_process_group(backend=args.dist_backend,
                                    init_method=args.dist_url,
                                    world_size=args.world_size,
                                    rank=args.rank)
            connected = True
            print('Established connection. Rank:', args.rank)
        except Exception as e:
            # Sometimes the head node launches after the worker, which would cause an issue
            print('Failed to init process group. Retrying...', recon_attempt,
                  e)
            recon_attempt += 1
            time.sleep(10)

    # logging
    if args.rank == 0:
        save_folder = os.path.join(args.out_dir, args.experiment)
        os.makedirs(save_folder, exist_ok=True)
        t_writer = SummaryWriter(os.path.join(save_folder, 'train'),
                                 flush_secs=5)
        v_writer = SummaryWriter(os.path.join(save_folder, 'val'),
                                 flush_secs=5)
        importlib.reload(logging)
        logging.basicConfig(filename=os.path.join(save_folder, 'train.log'),
                            level=logging.INFO,
                            format='%(asctime)s--- %(message)s')
        logging.info(
            '\n*******************************************************************************\n'
        )
        logging.info("the configuration:")
        logging.info(str(args).replace(',', '\n'))

    print('Loading models...')
    cache_dir = os.path.join(args.out_dir, 'model_cache')
    os.makedirs(cache_dir, exist_ok=True)
    # Load pre-trained teacher tokenizer (vocabulary)
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=cache_dir)
    # Hack to allow tokenizing longer sequences.
    tokenizer.max_len = int(1e12)
    gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir=cache_dir)
    print('gpt2_params:', num_params(gpt2_model))  # gpt2: 124439808
    config = GPT2Config()

    # add special tokens
    # special_tokens_dict = {
    #     'pad_token': '<|startoftext|>',
    #     'cls_token': '<|startofcond|>',
    #     'sep_token': '<|sepofcond|>',
    #     'mask_token': '<|endofcond|>'
    # }
    # num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
    # print('We have added', num_added_toks, 'special tokens')
    # # Notice: resize_token_embeddings expect to receive the full size of the new vocab
    # gpt2_model.resize_token_embeddings(len(tokenizer))
    # assert tokenizer.pad_token == '<|startoftext|>'

    VAE = VAEModel(config,
                   add_input=args.add_input,
                   add_attn=args.add_attn,
                   add_softmax=args.add_softmax,
                   attn_proj_vary=args.attn_proj_vary,
                   learn_prior=args.learn_prior)
    init_para_frompretrained(VAE.transformer,
                             gpt2_model.transformer,
                             share_para=True)
    init_para_frompretrained(VAE.encoder,
                             gpt2_model.transformer,
                             share_para=False)
    if args.learn_prior:
        init_para_frompretrained(VAE.encoder_prior,
                                 VAE.encoder,
                                 share_para=True)
        VAE.encoder_prior.averageSelfAttention.attention_weights = VAE.encoder.averageSelfAttention.attention_weights
    VAE.lm_head.weight = gpt2_model.lm_head.weight
    if VAE.add_softmax:
        VAE.lm_head_rep = Conv1D(*gpt2_model.lm_head.weight.size())
        # VAE.lm_head_rep = LM_head_rep(*gpt2_model.lm_head.weight.size()[::-1])
    print('VAE_params:', num_params(VAE))  # 286694400
    if args.load:
        print('Loading model weights...')
        state = torch.load(os.path.join(
            args.load, 'model_latest.pt'))  # , map_location='cpu'
        if 'module' in list(state.keys(
        ))[0]:  # model_path is data parallel model with attr 'module'
            state_copy = copy.copy(state)
            keys = state_copy.keys()
            for k in keys:
                state[k.replace('module.', '')] = state.pop(k)
        VAE.load_state_dict(state)
        gc.collect()
    print('Done.')

    # fix pre-trained parameters before certain iterations
    tuning_all_after_iters = 10000
    tuning_all = False
    for name, parameter in VAE.named_parameters():
        # print((name, parameter.requires_grad))
        new_pars = [
            'c_z', 'attention_weights', 'mean', 'logvar', 'input_proj',
            'attn_proj', 'Nu_fc1', 'Nu_fc2', 'lm_head_rep'
        ]

        if not any(n in name for n in new_pars):
            parameter.requires_grad = False

    print('Setup data...')
    # Batch and sequence length schedule
    assert len(args.batch_sizes) == len(args.seq_lens)
    batch_schedule = list(
        zip(map(int, args.batch_sizes), map(int, args.seq_lens)))
    assert len(
        batch_schedule) <= 2, 'Currently not supporting multiple schedule'
    cur_b_schedule = len(batch_schedule) - 1 if args.switch_time == 0 else 0
    print('Batch schedule', batch_schedule)
    train_loader, val_loader, test_loader = prepare_dataset(
        args.data_dir,
        args.dataset,
        tokenizer,
        batch_schedule[cur_b_schedule][0],
        batch_schedule[cur_b_schedule][1],
        batch_schedule[-1][0],
        batch_schedule[-1][1],
        batch_schedule[-1][0],
        batch_schedule[-1][1],
        make_test=True,
        num_workers=args.workers,
        data_type=args.data_type)
    print('Done.')

    ###
    val_loader = test_loader
    ###

    print('Wrapping models and optimizers...')
    # Apply linear scaling rule to increase batch size for short sequence training.
    lr_schedule = switch_schedule(
        linear_schedule(args),
        batch_schedule[cur_b_schedule][0] / batch_schedule[-1][0],
        int(args.iterations * args.switch_time))
    VAE = VAE.to(device)
    VAE = VAE.train()

    optimizer = AdamW(VAE.parameters(), lr=args.lr, correct_bias=True)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)
    VAE, optimizer = amp.initialize(VAE,
                                    optimizer,
                                    opt_level=args.fp16_opt_level)
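    # Distributed data-parallel wrapper: gradients are averaged across ranks during backward.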
    loss_model = DDP(VAE)  # , delay_allreduce=True

    loss_fn = nn.CrossEntropyLoss(reduction='none')
    print('Done.')

    print('Begin training iterations')
    logging.info("Begin training iterations")
    max_val_batches = 20000  # max num. of val batches
    logging.info("Total iteration: %d" % args.iterations)
    e = 0  # epoch counter
    num_iters = 0
    optimizer.zero_grad()
    beta = args.beta_0
    endoftext = tokenizer.convert_tokens_to_ids("<|endoftext|>")

    def val_step(val_loader):
        VAE.eval()

        n_words_bpe = 0
        n_words = 0
        logp_sum = 0.0
        kl_loss_sum = 0.0

        logging.info("Validation loop.         Batches: %d" % len(val_loader))
        logging.info("Validation loop. max_val_batches: %d" % max_val_batches)

        # val_iter = iter(val_loader); x_mask, x_tokens, y_mask, y_tokens, input_tokens, target_tokens, mask = next(val_iter)
        with tqdm(total=min(len(val_loader), max_val_batches)) as pbar:
            for i, (x_mask, x_tokens, y_mask, y_tokens, input_tokens,
                    target_tokens, mask) in enumerate(val_loader):
                with torch.no_grad():
                    if args.model_type == 'cvae':
                        loss, ce_loss, kl_loss = compute_loss(
                            device, VAE, x_mask, x_tokens, y_mask, y_tokens,
                            input_tokens, target_tokens, mask, loss_fn, 1.0)
                    else:
                        loss, ce_loss, kl_loss = compute_loss_ae(
                            device, VAE, x_mask, x_tokens, y_mask, y_tokens,
                            input_tokens, target_tokens, mask, loss_fn, 1.0)

                if len(target_tokens.size()) == 1:
                    target_tokens = target_tokens.unsqueeze(0)
                n, l = target_tokens.size()

                text = target_tokens[0, :].tolist()
                logprob = ce_loss.tolist()
                assert len(text) == len(logprob)

                # story data: drop the prompt, i.e. everything up to and including the first <|endoftext|>
                idx = text.index(endoftext)
                text = text[idx + 1:]
                logprob = logprob[idx + 1:]

                if endoftext in text:
                    idx = text.index(endoftext)
                    text = text[:idx]
                    logprob = logprob[:idx]

                logp_sum += sum(logprob)

                n_words_bpe += len(text)

                story = [
                    tokenizer.decode(target_tokens[i, :]) for i in range(n)
                ]
                story = [
                    s[s.find("<|endoftext|>") + len("<|endoftext|>"):]
                    for s in story
                ]
                story = [
                    s[:s.find("<|endoftext|>") +
                      len("<|endoftext|>")] if "<|endoftext|>" in s else s
                    for s in story
                ]
                words = sum([
                    len([
                        t for t in re.split(
                            '("|\'|!|\?|\.|,|:| |\n|’|“|”|;|\(|\)|`)', s)
                        if t != ' ' and t != ''
                    ]) for s in story
                ])
                n_words += words

                kl_loss_sum += kl_loss.item()

                if i > max_val_batches:
                    break
                pbar.update(1)

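        # Perplexity = exp(mean negative log-likelihood), reported both per BPE token and per
        # whitespace/punctuation-split word; the exponent is capped to avoid overflow.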
        loss_bpe = logp_sum / n_words_bpe
        ppl_bpe = round(math.exp(min(logp_sum / n_words_bpe, 100)), 3)
        ppl_word = round(math.exp(min(logp_sum / n_words, 100)), 3)
        kl = kl_loss_sum / len(val_loader)

        v_writer.add_scalar('loss', loss_bpe, num_iters)
        v_writer.add_scalar('ppl_bpe', ppl_bpe, num_iters)
        v_writer.add_scalar('ppl_word', ppl_word, num_iters)
        v_writer.add_scalar('kl', kl, num_iters)
        logging.info('val loss    : %.4f' % loss_bpe)
        logging.info('val ppl_bpe : %.4f' % ppl_bpe)
        logging.info('val ppl_word: %.4f' % ppl_word)
        logging.info('val   kl    : %.4f' % kl)

        VAE.train()

    def test_plot(test_loader, num_iters):
        VAE.eval()

        # get embedding
        X_emb = None
        y = None

        # test_iter = iter(test_loader); x_mask, x_tokens, y_mask, y_tokens, input_tokens, target_tokens, mask = next(test_iter)
        with tqdm(total=len(test_loader)) as pbar:
            for i, (x_mask, x_tokens, y_mask, y_tokens, input_tokens,
                    target_tokens, mask) in enumerate(test_loader):
                y_mask = y_mask.to(device)
                y_tokens = y_tokens.to(device)
                x_mask = x_mask.to(device)
                x_tokens = x_tokens.to(device)
                with torch.no_grad():
                    if args.model_type == 'cvae':
                        latent_mean, _ = VAE.encoder_prior(
                            input_ids=x_tokens, attention_mask=x_mask)[:2]
                    else:
                        latent_mean, _ = VAE.encoder(input_ids=x_tokens,
                                                     attention_mask=x_mask)[:2]

                if args.dataset == 'ax' or args.dataset == 'yp':
                    label = [
                        tokenizer.decode(l)[:2] for l in x_tokens.tolist()
                    ]
                elif args.dataset == 'wp':
                    label = []
                    prompts = [
                        tokenizer.decode(l)[:6].lower()
                        for l in x_tokens.tolist()
                    ]
                    for prom in prompts:
                        if prom[0] in ['[', '('] and prom[5] in [']', ')']:
                            label.append(prom[2:4])
                        else:
                            label.append(None)
                elif args.dataset == 'wi':
                    # 0. TV, play, miniseries, telenovela; 1.film; 2. music; 3. manga, comic, 4. book, novel, story 5. game
                    label = []
                    prompts = [tokenizer.decode(l) for l in x_tokens.tolist()]
                    for prom in prompts:
                        if 'TV' in prom or 'play' in prom or 'miniseries' in prom or 'telenovela' in prom:
                            label.append(0)
                        elif 'film' in prom:
                            label.append(1)
                        elif 'music' in prom:
                            label.append(2)
                        elif 'manga' in prom or 'comic' in prom:
                            label.append(3)
                        elif 'book' in prom or 'novel' in prom or 'story' in prom:
                            label.append(4)
                        elif 'game' in prom:
                            label.append(5)
                        else:
                            label.append(None)
                else:
                    raise Exception

                if i == 0:
                    X_emb = latent_mean.data
                    y = label
                else:
                    X_emb = torch.cat((X_emb, latent_mean.data), dim=0)
                    y.extend(label)
                pbar.update(1)
        X_emb = X_emb.cpu().numpy()

        try:
            if args.dataset == 'yp':
                y = ['0' if l in ['0', '1'] else l for l in y]
                y = ['4' if l in ['3', '4'] else l for l in y]
                X_emb = X_emb[[l != '2' for l in y], :]
                y = [l for l in y if l != '2']

            if args.dataset == 'wp':
                topics = [['wp', 'sp', 'tt'], ['eu'], ['cw'], ['pm'],
                          ['mp', 'ip'], ['pi', 'cc'], ['ot'], ['rf']]
                match = [[True if l in t else False for t in topics]
                         for l in y]
                y = [m.index(True) if True in m else None for m in match]
                X_emb = X_emb[[l is not None for l in y], :]
                y = [l for l in y if l is not None]

            if args.dataset == 'wi':
                X_emb = X_emb[[l is not None for l in y], :]
                y = [l for l in y if l is not None]

            # to 2D
            # X_emb_2d = TSNE(n_components=2, init='pca', verbose=1).fit_transform(X_emb)
            X_emb_2d = TSNE(n_components=2, verbose=1,
                            perplexity=40).fit_transform(X_emb)

            def remove_outliers(data, r=2.0):
                outliers_data = abs(
                    data - np.mean(data, axis=0)) >= r * np.std(data, axis=0)
                outliers = np.any(outliers_data, axis=1)
                keep = np.logical_not(outliers)
                return outliers, keep

            outliers, keep = remove_outliers(X_emb_2d)
            X_emb_2d = X_emb_2d[keep, :]
            y = [l for l, k in zip(y, keep.tolist()) if k]

            # plot
            fig = plt.figure(figsize=(4, 4))
            ax = fig.add_axes([0, 0, 1, 1])
            cc = ['r', 'b', 'g', 'y', 'k', 'c', 'm', 'tab:blue']
            for i, l in enumerate(sorted(set(y))):
                idx = [yl == l for yl in y]
                plt.scatter(X_emb_2d[idx, 0],
                            X_emb_2d[idx, 1],
                            c=cc[i],
                            s=10,
                            edgecolor='none',
                            alpha=0.5)
            ax.axis('off')  # hide the axes
            plt.savefig(
                os.path.join(save_folder,
                             'tSNE_' + '{:07d}'.format(num_iters) + '.png'))
            plt.close(fig)
        except:
            pass

        VAE.train()

    def generate(test_loader, num_iters):
        VAE.eval()

        n_samples = 0
        bleu4_sum = 0.0
        rouge_scores_values_sum = [0.0] * 9

        args.nsamples = 1
        args.batch_size = 1
        args.temperature = 0.95
        args.top_k = 100
        args.top_p = 0.95
        model_type = args.model_type

        # write samples to file
        samples_file = open(os.path.join(
            save_folder, 'generate-' + '%07d' % num_iters + '.txt'),
                            'w',
                            encoding='utf8')

        # test_iter = iter(test_loader); x_mask, x_tokens, y_mask, y_tokens, input_tokens, target_tokens, mask = next(test_iter)
        with tqdm(total=len(test_loader)) as pbar:
            for i_test, (x_mask, x_tokens, y_mask, y_tokens, input_tokens,
                         target_tokens, mask) in enumerate(test_loader):

                if i_test >= 10: break

                length = -1
                if length == -1:
                    length = VAE.config.n_ctx - x_tokens.size(1) - 1
                elif length > VAE.config.n_ctx - x_tokens.size(1) - 1:
                    raise ValueError(
                        "Can't get samples longer than window size: %s" %
                        VAE.config.n_ctx)

                eff_samples = []
                n, l = target_tokens.size()
                storys = [
                    tokenizer.decode(target_tokens[i, :]) for i in range(n)
                ]
                storys = [
                    s[s.find("<|endoftext|>") + len("<|endoftext|>"):]
                    for s in storys
                ]
                storys_str = [
                    s[:s.find("<|endoftext|>") +
                      len("<|endoftext|>")] if "<|endoftext|>" in s else s
                    for s in storys
                ]

                for _ in range(args.nsamples // args.batch_size):
                    # model, batch_size, temperature, top_k, top_p, eos_token, sample = VAE, args.batch_size, args.temperature, args.top_k, args.top_p, tokenizer.encoder['<|endoftext|>'], True
                    out, _ = sample_sequence(
                        model=VAE,
                        tokenizer=tokenizer,
                        length=length,
                        batch_size=args.batch_size,
                        x_mask=x_mask,
                        x_tokens=x_tokens,
                        y_mask=y_mask,
                        y_tokens=y_tokens,
                        temperature=args.temperature,
                        top_k=args.top_k,
                        top_p=args.top_p,
                        device=device,
                        eos_token=tokenizer.encoder['<|endoftext|>'],
                        model_type=model_type)
                    out = out.tolist()

                    # extract story, check metrics
                    for i in range(len(out)):
                        text = out[i]
                        text = text[text.index(endoftext) + 1:]

                        if endoftext in text:
                            idx = text.index(endoftext)
                            text = text[:idx]

                        text = tokenizer.decode(text).strip()

                        # score for one long text, higher than 0.075 usually means repetition
                        # rep_score = repeat_score(text.split(), ngram=[3, 4, 5, 6, 7, 8])
                        # if rep_score > 0.075:
                        #     # print(rep_score)
                        #     continue

                        try:
                            # check bleu (tokenize the hypothesis to match the reference tokens)
                            bleu4 = sentence_bleu(
                                [storys_str[i].split()],
                                text.split(),
                                smoothing_function=SmoothingFunction().method7)

                            # check rouge
                            rouge = Rouge()
                            rouge_scores = rouge.get_scores(
                                text, storys_str[i])
                            rouge_scores_values = [
                                v for k in rouge_scores[0].keys()
                                for v in rouge_scores[0][k].values()
                            ]

                            bleu4_sum += bleu4
                            rouge_scores_values_sum = [
                                v1 + v2
                                for v1, v2 in zip(rouge_scores_values_sum,
                                                  rouge_scores_values)
                            ]
                            n_samples += 1
                        except:
                            bleu4 = 0.0
                            rouge_scores = [{
                                'rouge-1': {
                                    'f': 0.0,
                                    'p': 0.0,
                                    'r': 0.0
                                },
                                'rouge-2': {
                                    'f': 0.0,
                                    'p': 0.0,
                                    'r': 0.0
                                },
                                'rouge-l': {
                                    'f': 0.0,
                                    'p': 0.0,
                                    'r': 0.0
                                }
                            }]

                        eff_samples.append((text, bleu4, rouge_scores))

                    pbar.update(1)

                for i in range(len(eff_samples)):
                    samples_file.write("=" * 50 + " SAMPLE " + str(i_test) +
                                       " " + "=" * 50)
                    samples_file.write('\n' * 2)

                    samples_file.write("=" * 40 + " Outlines  " + "=" * 40)
                    samples_file.write('\n' * 2)
                    samples_file.write(
                        tokenizer.decode(
                            x_tokens[i, :][x_mask[i, :] == 1].tolist()))
                    samples_file.write('\n' * 2)
                    samples_file.write("=" * 40 + " Story " + "=" * 40)
                    samples_file.write('\n' * 2)
                    samples_file.write(storys_str[i])
                    samples_file.write('\n' * 2)

                    samples_file.write("=" * 40 + " Generated " + "=" * 40)
                    samples_file.write('\n' * 2)
                    samples_file.write(eff_samples[i][0])
                    samples_file.write('\n' * 4)
                    samples_file.flush()

        samples_file.close()
        print('Test complete with %05d samples.' % n_samples)
        logging.info("Test complete with %05d samples.", n_samples)
        logging.info("Iteration completed: %d" % num_iters)

        bleu4 = round(bleu4_sum / n_samples, 3)
        rouge_scores_values = [
            round(r / n_samples, 3) for r in rouge_scores_values_sum
        ]
        print(' bleu-4:', bleu4)
        print(' rouge :', rouge_scores_values)
        logging.info(' bleu-4: %f', bleu4)
        logging.info(' rouge : %s', str(rouge_scores_values))

        VAE.train()

    if args.rank == 0:
        test_plot(test_loader, num_iters)
        val_step(val_loader)
        generate(test_loader, num_iters)
        torch.save(
            VAE.state_dict(),
            os.path.join(save_folder,
                         'model_' + '{:07d}'.format(num_iters) + '.pt'))

    while num_iters < args.iterations:
        # Run epoch
        st = time.time()

        # Training
        print('Training loop. Batches:', len(train_loader))
        logging.info(
            '\n----------------------------------------------------------------------'
        )
        logging.info("Training loop.       Batches: %d" % len(train_loader))

        # train_iter = iter(train_loader); x_mask, x_tokens, y_mask, y_tokens, input_tokens, target_tokens, mask = next(train_iter)
        with tqdm(total=len(train_loader)) as pbar:
            for i, (x_mask, x_tokens, y_mask, y_tokens, input_tokens,
                    target_tokens, mask) in enumerate(train_loader):

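                # Cyclical KL annealing: beta ramps from beta_0 toward 1 (capped at 1) during
                # the last (beta_warmup + 25000) iterations of each cycle and is reset to
                # beta_0 at every cycle boundary further below.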
                if num_iters % args.cycle >= args.cycle - args.beta_warmup - 25000:
                    beta = min(1.0,
                               beta + (1. - args.beta_0) / args.beta_warmup)

                if not tuning_all and num_iters >= tuning_all_after_iters:
                    for name, parameter in VAE.named_parameters():
                        # print((name, parameter.requires_grad))
                        parameter.requires_grad = True
                    tuning_all = True

                output = train_step(device, loss_model, optimizer, x_mask,
                                    x_tokens, y_mask, y_tokens, input_tokens,
                                    target_tokens, mask, loss_fn, beta,
                                    args.model_type)

                if args.rank == 0:
                    loss, ce_loss, kl_loss = output[-1]
                    lr = scheduler.get_last_lr()[0]
                    # Log to Tensorboard
                    t_writer.add_scalar('loss', loss, num_iters)
                    t_writer.add_scalar('ppl', math.exp(min(ce_loss, 10)),
                                        num_iters)
                    t_writer.add_scalar('lr', lr, num_iters)
                    t_writer.add_scalar('iter_time',
                                        time.time() - st, num_iters)
                    t_writer.add_scalar('kl', kl_loss, num_iters)
                    t_writer.add_scalar('beta', beta, num_iters)

                    if args.model_type == 'ae_vae_fusion':
                        loss, ce_loss, kl_loss = output[0]
                        # Log to Tensorboard
                        t_writer.add_scalar('ae_loss', loss, num_iters)
                        t_writer.add_scalar('ae_kl', kl_loss, num_iters)

                st = time.time()
                end = num_iters >= args.iterations

                if args.warmup != -1:
                    scheduler.step()

                if end: break
                num_iters += 1
                pbar.update(1)

                if num_iters % args.cycle == 0:
                    beta = args.beta_0
                    logging.info('KL annealing restart')

                if args.rank == 0 and num_iters % 10000 == 0:
                    test_plot(test_loader, num_iters)
                    val_step(val_loader)
                    generate(test_loader, num_iters)

                if args.rank == 0 and num_iters % 10000 == 0:
                    print('Saving model...')
                    logging.info("Iteration completed: %d, remained %d" %
                                 (num_iters, args.iterations - num_iters))
                    logging.info("Saving model...")
                    logging.info(
                        '\n------------------------------------------------------'
                    )
                    torch.save(
                        VAE.state_dict(),
                        os.path.join(
                            save_folder,
                            'model_' + '{:07d}'.format(num_iters) + '.pt'))

                if args.switch_time > 0 and num_iters == int(
                        args.iterations * args.switch_time):
                    print('Switch to long sequence training')
                    logging.info("Switch to long sequence training")
                    cur_b_schedule += 1
                    train_loader, val_loader, test_loader = prepare_dataset(
                        args.data_dir,
                        args.dataset,
                        tokenizer,
                        batch_schedule[cur_b_schedule][0],
                        batch_schedule[cur_b_schedule][1],
                        batch_schedule[-1][0],
                        batch_schedule[-1][1],
                        batch_schedule[-1][0],
                        batch_schedule[-1][1],
                        make_test=True,
                        num_workers=args.workers,
                        data_type=args.data_type)
        if not end:
            e += 1
            logging.info("Training loop. The ith epoch completed: %d" % e)

    if args.rank == 0:
        torch.save(VAE.state_dict(),
                   os.path.join(save_folder, 'model_latest.pt'))
    print('Training complete.')
    logging.info("Training complete.")
Example #9
0
    def load_data(self) -> Tuple[DataLoader, Dict[str, int]]:  # needs: from typing import Tuple, Dict
        label2idx = {"yes": 1, "no": 0}

        input_size = self.input_size

        class SampleDataset(Dataset):
            def __init__(self):
                self.count = 0

            def __len__(self):
                return 1000

            def __getitem__(self, idx):
                start = min(label2idx.values())
                end = max(label2idx.values()) + 1
                y = np.random.randint(start, end)

                self.count += 1
                return torch.rand(input_size), y, f"id_{self.count}"

        ds = SampleDataset()
        return DataLoader(ds, batch_size=1000), label2idx

        # Make up some training data
        training_data = [(
            "the wall street journal reported today that apple corporation made money"
            .split(), "B I I I O O O B I O O".split()),
                         ("georgia tech is a university in georgia".split(),
                          "B I O O O O B".split())]

        # char
        char_input = Input(shape=(
            None,
            char_maxlen,
        ), name="char_input")
        # characters are indexed directly by ASCII code, so no separate char vocabulary is kept (input_dim=128)
        char_layer = TimeDistributed(Embedding(
            input_dim=128,
            output_dim=30,
            embeddings_initializer=RandomUniform(minval=-0.5, maxval=0.5)),
                                     trainable=True,
                                     name="char_emb")(char_input)
        char_layer = Dropout(0.5, name="char_dropout_1")(char_layer)
        char_layer = TimeDistributed(Conv1D(kernel_size=3,
                                            filters=30,
                                            padding='same',
                                            activation='tanh',
                                            strides=1),
                                     name="char_conv")(char_layer)

        char_layer = TimeDistributed(MaxPooling1D(char_maxlen),
                                     name="char_maxpool")(char_layer)
        char_layer = TimeDistributed(Flatten(),
                                     name="char_flatten")(char_layer)
        char_layer = Dropout(0.5, name="char_dropout_2")(char_layer)

        # word
        words_input = Input(shape=(None, ), dtype="int32", name="words_input")

        if word_emb is not None:  # word_emb.shape = (vocab_size + 1, emb_dim), e.g. emb_dim = 50
            words_layer = Embedding(input_dim=word_emb.shape[0],
                                    output_dim=word_emb.shape[1],
                                    weights=[word_emb],
                                    mask_zero=True,
                                    trainable=False,
                                    name="word_emb")(words_input)
        else:
            words_layer = Embedding(input_dim=vocab_size + 1,
                                    output_dim=emb_dim,
                                    mask_zero=True,
                                    trainable=False,
                                    name="word_emb")(words_input)

        feat_input = Input(shape=(None, feat_size),
                           dtype="float32",
                           name="feat_input")
        feat_layer = Dense(input_dim=feat_size,
                           activation="relu",
                           units=feat_size,
                           trainable=True)(feat_input)
        shared_output = concatenate([words_layer, char_layer, feat_layer],
                                    name="concat")

        shared_output = Bidirectional(LSTM(bi_rnn_units,
                                           return_sequences=True,
                                           dropout=0.5,
                                           recurrent_dropout=0.25),
                                      name="BiLSTM_1")(shared_output)
        shared_output = MultiHeadAttention(
            head_num=multi_head_num, name="Multihead_Attn")(shared_output)
        shared_output = BatchNormalization(name="BatchNorm1")(shared_output)

        intent_output = Bidirectional(LSTM(bi_rnn_units,
                                           return_sequences=False,
                                           dropout=0.5,
                                           recurrent_dropout=0.25),
                                      name="Intent_BiLSTM")(shared_output)
        intent_output = BatchNormalization(
            name="Intent_BatchNorm2")(intent_output)
        intent_output = Dense(label_size[0],
                              activation="softmax",
                              name="Intent_softmax")(intent_output)

        # had to set Embedding mask_zero=True to avoid negative loss. https://github.com/keras-team/keras-contrib/issues/278
        bi_crf = CRF(label_size[1], sparse_target=True, name="BI_CRF")
        bi_output = bi_crf(shared_output)

        concept_crf = CRF(label_size[2],
                          sparse_target=True,
                          name="Concept_CRF")
        concept_output = concept_crf(shared_output)

        emd_crf = CRF(label_size[3], sparse_target=True, name="EMD_CRF")
        emd_output = emd_crf(shared_output)

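        # Multi-task model: the shared BiLSTM + multi-head-attention encoder feeds one intent
        # classification head (softmax) and three CRF sequence-labelling heads.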
        model = Model(
            inputs=[words_input, char_input, feat_input],
            outputs=[intent_output, bi_output, concept_output, emd_output])
        adam_optimizer = Adam(beta_1=0.9, beta_2=0.98, epsilon=10e-9)

        model.compile(optimizer=adam_optimizer,
                      loss=[
                          "categorical_crossentropy",
                          crf_loss_joined(bi_crf),
                          crf_loss_joined(concept_crf),
                          crf_loss_joined(emd_crf)
                      ],
                      metrics={
                          "Intent_softmax": "accuracy",
                          "BI_CRF": crf_accuracy_joined(bi_crf),
                          "Concept_CRF": crf_accuracy_joined(concept_crf),
                          "EMD_CRF": crf_accuracy_joined(emd_crf)
                      })