示例#1
0
文件: main.py 项目: heyLinsir/cotk
def create_model(sess, data, args, embed):
    """Build a ``VAEModel`` inside the ``args.name`` variable scope and load or
    initialise its parameters.

    Restore policy (``args.restore``):
      * ``"last"`` -- restore from ``<args.model_dir>/checkpoint_latest`` when
        that directory holds a checkpoint;
      * ``"best"`` -- restore from ``<args.model_dir>/checkpoint_best`` when
        that directory holds a checkpoint;
      * anything else, or a missing checkpoint -- run the initialisers of every
        global variable whose name *contains* ``args.name`` (substring match).

    Returns:
        The constructed model.
    """
    with tf.variable_scope(args.name):
        model = VAEModel(data, args, embed)
        model.print_parameters()
        ckpt_latest = '%s/checkpoint_latest' % args.model_dir
        ckpt_best = '%s/checkpoint_best' % args.model_dir

        # The checkpoint-state lookup is evaluated before the restore-mode
        # comparison, exactly as in the original condition order.
        if tf.train.get_checkpoint_state(ckpt_latest) and args.restore == "last":
            print("Reading model parameters from %s" % ckpt_latest)
            model.latest_saver.restore(
                sess, tf.train.latest_checkpoint(ckpt_latest))
        elif tf.train.get_checkpoint_state(ckpt_best) and args.restore == "best":
            print('Reading model parameters from %s' % ckpt_best)
            model.best_saver.restore(
                sess, tf.train.latest_checkpoint(ckpt_best))
        else:
            print("Created model with fresh parameters.")
            scoped_vars = list(
                filter(lambda var: args.name in var.name,
                       tf.global_variables()))
            sess.run(tf.variables_initializer(scoped_vars))

    return model
示例#2
0
class Trainer(object):
    """Train a ``VAEModel`` with a TF1 graph/session loop.

    Args:
        lr: Learning rate for the Adam optimiser.
        **kwargs: Forwarded unchanged to ``VAEModel``.
    """

    def __init__(self, lr=1e-4, **kwargs):
        self.lr = lr
        self.vae = VAEModel(**kwargs)
        # Folder of training images handed to ``VAEModel.train_graph``.
        self.data_path = "/datasets/Img/img_align_celeba/img_align_celeba"

    def train(self, ckpt_path="./saved_models/model.ckpt", log_dir="/tmp/log",
              save_every=1000):
        """Train until the input pipeline raises ``OutOfRangeError``.

        Minimises ``kl + mse`` (from ``self.vae.train_graph``) with Adam inside
        the graph ``self.vae.g_train``, writes KLLoss/MSELoss/TotalLoss
        summaries to ``log_dir`` every step, and checkpoints to ``ckpt_path``
        every ``save_every`` steps and once more when training ends.

        Args:
            ckpt_path: Checkpoint prefix passed to ``tf.train.Saver.save``.
            log_dir: TensorBoard summary directory.
            save_every: Checkpoint interval, in steps.
        """
        import os  # local import: only needed to create the checkpoint directory

        step = 0
        with self.vae.g_train.as_default():
            kl, mse = self.vae.train_graph(self.data_path)
            loss = kl + mse

            # Run the ops registered in UPDATE_OPS (typically batch-norm moving
            # averages) before every optimiser step.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                optim = tf.train.AdamOptimizer(
                    learning_rate=self.lr).minimize(loss)

            tf.summary.scalar('KLLoss', kl)
            tf.summary.scalar('MSELoss', mse)
            tf.summary.scalar('TotalLoss', loss)

            merged = tf.summary.merge_all()
            writer = tf.summary.FileWriter(log_dir, self.vae.g_train)

            # tf.train.Saver.save refuses to write into a missing parent dir.
            ckpt_dir = os.path.dirname(ckpt_path)
            if ckpt_dir and not os.path.isdir(ckpt_dir):
                os.makedirs(ckpt_dir)

            saver = tf.train.Saver()
            try:
                with tf.Session() as s:
                    s.run(tf.global_variables_initializer())

                    while True:
                        try:
                            _, kl_d, mse_d, loss_d, mgd = s.run(
                                [optim, kl, mse, loss, merged])
                        except tf.errors.OutOfRangeError:
                            # The input pipeline is exhausted.
                            print("Training finished...")
                            break
                        step += 1
                        writer.add_summary(mgd, step)

                        if step % save_every == 0:
                            print("Saving model...")
                            saver.save(s, ckpt_path)
                        print(
                            "Iter: {:10d}, KLLoss: {:12.6f}, MSELoss: {:12.6f}, TotalLoss: {:12.6f}"
                            .format(step, kl_d, mse_d, loss_d))

                    # Persist the steps run since the last periodic save.
                    if step % save_every != 0:
                        print("Saving model...")
                        saver.save(s, ckpt_path)
            finally:
                # Flush pending summaries and release the event file.
                writer.close()
示例#3
0
 def __init__(self, lr=1e-4, **kwargs):
     """Remember the learning rate and image folder, then build the VAE.

     Every keyword argument other than ``lr`` is forwarded to ``VAEModel``.
     """
     self.data_path = "/datasets/Img/img_align_celeba/img_align_celeba"
     self.lr = lr
     self.vae = VAEModel(**kwargs)
示例#4
0
# Select the decoder target for the task; any task other than 'depth' or
# 'segmentation' is treated as 'con2rgb' and trains on RGB targets.
target_key = {'depth': 'depths', 'segmentation': 'segs'}.get(args.task, 'rgbs')
train_ds, test_ds = import_data.load_dataset_for_training_decoder(
    args.dset_file, args.num_imgs, args.batch_size, target_key)

# Per-layer trainable flags for the 5 decoding layers: the first
# (5 - args.trainable_layers) layers are frozen.
frozen_layers = 5 - args.trainable_layers
trainable_layers = [True] * 5
for i in range(frozen_layers):
    trainable_layers[i] = False

# create model, loss and optimizer.
# depth / segmentation: frozen encoder, all saved weights restored.
# con2rgb: trainable encoder, only the decoder weights restored.
is_con2rgb = args.task not in ('depth', 'segmentation')
model = VAEModel(n_z=args.n_z,
                 trainable_encoder=is_con2rgb,
                 trainable_decoder=trainable_layers,
                 out_channels=1 if args.task == 'depth' else 3)
if is_con2rgb:
    model.load_decoder_weights_only(args.model_path)
else:
    model.load_weights(args.model_path)
optimizer = tf.keras.optimizers.Adam()

# create text log file (left open; presumably written by the epoch loop
# further down the script)
log_file = open(os.path.join(args.output_dir,"log.txt"),"w")
log_file.write("Epoch\tTrainLoss\tTestLoss\n")

# define metrics
train_loss = tf.keras.metrics.Mean(name='train_loss')
示例#5
0
# add special tokens
# Register four extra special tokens with the GPT-2 tokenizer; the GPT-2
# embedding matrix is then resized so the new token ids have rows.
special_tokens_dict = {
    'pad_token': '<|startoftext|>',
    'cls_token': '<|startofcond|>',
    'sep_token': '<|sepofcond|>',
    'mask_token': '<|endofcond|>'
}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
print('We have added', num_added_toks, 'special tokens')
# Notice: resize_token_embeddings expects to receive the full size of the new vocab
gpt2_model.resize_token_embeddings(len(tokenizer))
# Sanity check that the padding token was registered as requested.
assert tokenizer.pad_token == '<|startoftext|>'

# Build the VAE and initialise it from the pre-trained GPT-2 weights.
VAE = VAEModel(config,
               add_input=add_input,
               add_attn=add_attn,
               add_softmax=add_softmax)
# Decoder transformer is initialised with share_para=True, the encoder with
# share_para=False (presumably shared tensors vs. an independent copy --
# confirm in init_para_frompretrained).
init_para_frompretrained(VAE.transformer,
                         gpt2_model.transformer,
                         share_para=True)
init_para_frompretrained(VAE.encoder, gpt2_model.transformer, share_para=False)
# The VAE's LM head now references the same weight Parameter as GPT-2's.
VAE.lm_head.weight = gpt2_model.lm_head.weight
if VAE.add_softmax:
    # Extra output projection, shaped like GPT-2's LM-head weight.
    VAE.lm_head_rep = Conv1D(*gpt2_model.lm_head.weight.size())
print('VAE_params:', num_params(VAE))  # 286694400
print('Done.')

print('Loading model weights...')
# Load the saved state dict onto the CPU regardless of the device it was saved from.
state = torch.load(os.path.join(args.load, 'model_latest.pt'),
                   map_location='cpu')
if 'module' in list(state.keys(
示例#6
0
def main_worker(gpu, ngpus_per_node, args):
    """Train a GPT-2 based (C)VAE as one process of a distributed job.

    Args:
        gpu: Local GPU index driven by this process; combined with
            ``args.rank`` and ``ngpus_per_node`` to form the global rank.
        ngpus_per_node: Number of GPUs (processes) launched per node.
        args: Parsed command-line namespace. Besides reading many fields,
            this function writes ``learn_prior``, ``gpu`` and ``rank`` onto
            it, and the nested ``generate`` overwrites ``nsamples``,
            ``batch_size``, ``temperature``, ``top_k`` and ``top_p``.

    Rank 0 additionally owns all logging (TensorBoard writers and
    ``train.log``), the periodic evaluation (t-SNE plot, validation
    perplexity, sampled generations scored with BLEU-4/ROUGE) and
    checkpointing under ``args.out_dir/args.experiment``.
    """
    # Only the conditional VAE learns a separate prior network (VAE.encoder_prior).
    args.learn_prior = args.model_type == 'cvae'

    # GPU
    args.gpu = gpu
    print("There are ", torch.cuda.device_count(), " available GPUs!")
    # NOTE(review): `devices` is not defined in this function; it is assumed
    # to be a module-level global of the script.
    print('Using GPU devices {}'.format(devices))
    device = torch.device('cuda', args.gpu)
    torch.cuda.set_device(device)
    print('Current single GPU: {}'.format(torch.cuda.current_device()))

    # randomness
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # For multiprocessing distributed training, rank needs to be the global
    # rank among all the processes.
    args.rank = args.rank * ngpus_per_node + gpu
    print('Setting rank', args.rank)
    recon_attempt = 1
    connected = False
    if args.rank != 0:
        # Stall to have rank 0 node go first
        time.sleep(3)
    while not connected:
        try:
            dist.init_process_group(backend=args.dist_backend,
                                    init_method=args.dist_url,
                                    world_size=args.world_size,
                                    rank=args.rank)
            connected = True
            print('Established connection. Rank:', args.rank)
        except Exception as err:
            # Sometimes the head node launches after the worker; keep retrying.
            print('Failed to init process group. Retrying...', recon_attempt,
                  err)
            recon_attempt += 1
            time.sleep(10)

    # logging (rank 0 only: save_folder and the writers exist only there)
    if args.rank == 0:
        save_folder = os.path.join(args.out_dir, args.experiment)
        os.makedirs(save_folder, exist_ok=True)
        t_writer = SummaryWriter(os.path.join(save_folder, 'train'),
                                 flush_secs=5)
        v_writer = SummaryWriter(os.path.join(save_folder, 'val'),
                                 flush_secs=5)
        # Presumably reloaded so basicConfig below takes effect even if a
        # handler was already installed.
        importlib.reload(logging)
        logging.basicConfig(filename=os.path.join(save_folder, 'train.log'),
                            level=logging.INFO,
                            format='%(asctime)s--- %(message)s')
        logging.info(
            '\n*******************************************************************************\n'
        )
        logging.info("the configuration:")
        logging.info(str(args).replace(',', '\n'))

    print('Loading models...')
    cache_dir = os.path.join(args.out_dir, 'model_cache')
    os.makedirs(cache_dir, exist_ok=True)
    # Load pre-trained teacher tokenizer (vocabulary)
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=cache_dir)
    # Hack to allow tokenizing longer sequences.
    tokenizer.max_len = int(1e12)
    gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir=cache_dir)
    print('gpt2_params:', num_params(gpt2_model))  # gpt2: 124439808
    config = GPT2Config()

    # No special tokens are added in this variant: the stock GPT-2
    # vocabulary and embedding size are used.

    VAE = VAEModel(config,
                   add_input=args.add_input,
                   add_attn=args.add_attn,
                   add_softmax=args.add_softmax,
                   attn_proj_vary=args.attn_proj_vary,
                   learn_prior=args.learn_prior)
    init_para_frompretrained(VAE.transformer,
                             gpt2_model.transformer,
                             share_para=True)
    init_para_frompretrained(VAE.encoder,
                             gpt2_model.transformer,
                             share_para=False)
    if args.learn_prior:
        # The prior encoder starts from (and shares parameters with) the
        # posterior encoder, including its attention-pooling weights.
        init_para_frompretrained(VAE.encoder_prior,
                                 VAE.encoder,
                                 share_para=True)
        VAE.encoder_prior.averageSelfAttention.attention_weights = VAE.encoder.averageSelfAttention.attention_weights
    VAE.lm_head.weight = gpt2_model.lm_head.weight
    if VAE.add_softmax:
        VAE.lm_head_rep = Conv1D(*gpt2_model.lm_head.weight.size())
    print('VAE_params:', num_params(VAE))  # 286694400
    if args.load:
        print('Loading model weights...')
        state = torch.load(os.path.join(
            args.load, 'model_latest.pt'))  # , map_location='cpu'
        if 'module' in list(state.keys(
        ))[0]:  # model_path is data parallel model with attr 'module'
            # Rename keys in place (keeps the state dict object and any
            # metadata attached to it); iterate over a copy of the keys.
            state_copy = copy.copy(state)
            keys = state_copy.keys()
            for k in keys:
                state[k.replace('module.', '')] = state.pop(k)
        VAE.load_state_dict(state)
        gc.collect()
    print('Done.')

    # Freeze the pre-trained GPT-2 weights: only parameters whose name
    # contains one of `new_pars` train until `tuning_all_after_iters`,
    # after which the training loop unfreezes everything.
    tuning_all_after_iters = 10000
    tuning_all = False
    new_pars = [
        'c_z', 'attention_weights', 'mean', 'logvar', 'input_proj',
        'attn_proj', 'Nu_fc1', 'Nu_fc2', 'lm_head_rep'
    ]
    for name, parameter in VAE.named_parameters():
        if not any(n in name for n in new_pars):
            parameter.requires_grad = False

    print('Setup data...')
    # Batch and sequence length schedule
    assert len(args.batch_sizes) == len(args.seq_lens)
    batch_schedule = list(
        zip(map(int, args.batch_sizes), map(int, args.seq_lens)))
    assert len(
        batch_schedule) <= 2, 'Currently not supporting multiple schedule'
    cur_b_schedule = len(batch_schedule) - 1 if args.switch_time == 0 else 0
    print('Batch schedule', batch_schedule)
    train_loader, val_loader, test_loader = prepare_dataset(
        args.data_dir,
        args.dataset,
        tokenizer,
        batch_schedule[cur_b_schedule][0],
        batch_schedule[cur_b_schedule][1],
        batch_schedule[-1][0],
        batch_schedule[-1][1],
        batch_schedule[-1][0],
        batch_schedule[-1][1],
        make_test=True,
        num_workers=args.workers,
        data_type=args.data_type)
    print('Done.')

    # NOTE: validation deliberately runs on the test split.
    val_loader = test_loader

    print('Wrapping models and optimizers...')
    # Apply linear scaling rule to increase batch size for short sequence training.
    lr_schedule = switch_schedule(
        linear_schedule(args),
        batch_schedule[cur_b_schedule][0] / batch_schedule[-1][0],
        int(args.iterations * args.switch_time))
    VAE = VAE.to(device)
    VAE = VAE.train()

    optimizer = AdamW(VAE.parameters(), lr=args.lr, correct_bias=True)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)
    VAE, optimizer = amp.initialize(VAE,
                                    optimizer,
                                    opt_level=args.fp16_opt_level)
    loss_model = DDP(VAE)  # , delay_allreduce=True

    loss_fn = nn.CrossEntropyLoss(reduction='none')
    print('Done.')

    print('Begin training iterations')
    logging.info("Begin training iterations")
    max_val_batches = 20000  # max num. of val batches
    logging.info("Total iteration: %d" % args.iterations)
    e = 0  # number of epoch
    num_iters = 0
    optimizer.zero_grad()
    beta = args.beta_0
    endoftext = tokenizer.convert_tokens_to_ids("<|endoftext|>")

    def val_step(val_loader):
        """Evaluate on ``val_loader`` and log loss, perplexities and KL.

        Evaluates at most ``max_val_batches`` batches under ``torch.no_grad``
        using ``compute_loss`` ('cvae') or ``compute_loss_ae``, with the final
        argument fixed to 1.0 (presumably the KL weight). BPE perplexity uses
        the first sequence of each batch; word perplexity divides the same
        log-prob sum by a regex word count over all decoded stories in the
        batch. Results go to ``v_writer`` and ``train.log`` at ``num_iters``.
        """
        VAE.eval()

        # Word splitter for the word-level perplexity; separators are kept
        # (capturing group) and filtered out below.
        word_split_re = re.compile(r'("|\'|!|\?|\.|,|:| |\n|’|“|”|;|\(|\)|`)')

        n_words_bpe = 0
        n_words = 0
        n_batches = 0
        logp_sum = 0.0
        kl_loss_sum = 0.0

        logging.info("Validation loop.         Batches: %d" % len(val_loader))
        logging.info("Validation loop. max_val_batches: %d" % max_val_batches)

        with tqdm(total=min(len(val_loader), max_val_batches)) as pbar:
            for i, (x_mask, x_tokens, y_mask, y_tokens, input_tokens,
                    target_tokens, mask) in enumerate(val_loader):
                # Evaluate exactly `max_val_batches` batches (the bar's total).
                if i >= max_val_batches:
                    break
                with torch.no_grad():
                    if args.model_type == 'cvae':
                        loss, ce_loss, kl_loss = compute_loss(
                            device, VAE, x_mask, x_tokens, y_mask, y_tokens,
                            input_tokens, target_tokens, mask, loss_fn, 1.0)
                    else:
                        loss, ce_loss, kl_loss = compute_loss_ae(
                            device, VAE, x_mask, x_tokens, y_mask, y_tokens,
                            input_tokens, target_tokens, mask, loss_fn, 1.0)

                if len(target_tokens.size()) == 1:
                    target_tokens = target_tokens.unsqueeze(0)
                n, _ = target_tokens.size()

                text = target_tokens[0, :].tolist()
                logprob = ce_loss.tolist()
                assert len(text) == len(logprob)

                # only for story data: skip tokens up to and including the
                # first <|endoftext|>, then truncate at the next one.
                idx = text.index(endoftext)
                text = text[idx + 1:]
                logprob = logprob[idx + 1:]

                if endoftext in text:
                    idx = text.index(endoftext)
                    text = text[:idx]
                    logprob = logprob[:idx]

                logp_sum += sum(logprob)

                n_words_bpe += len(text)

                story = [
                    tokenizer.decode(target_tokens[j, :]) for j in range(n)
                ]
                story = [
                    s[s.find("<|endoftext|>") + len("<|endoftext|>"):]
                    for s in story
                ]
                story = [
                    s[:s.find("<|endoftext|>") +
                      len("<|endoftext|>")] if "<|endoftext|>" in s else s
                    for s in story
                ]
                n_words += sum(
                    len([t for t in word_split_re.split(s)
                         if t != ' ' and t != ''])
                    for s in story)

                kl_loss_sum += kl_loss.item()
                n_batches += 1
                pbar.update(1)

        # max(..., 1) guards against empty loaders / empty texts.
        loss_bpe = logp_sum / max(n_words_bpe, 1)
        ppl_bpe = round(math.exp(min(loss_bpe, 100)), 3)
        ppl_word = round(math.exp(min(logp_sum / max(n_words, 1), 100)), 3)
        # Average over the batches actually evaluated, not the whole loader.
        kl = kl_loss_sum / max(n_batches, 1)

        v_writer.add_scalar('loss', loss_bpe, num_iters)
        v_writer.add_scalar('ppl_bpe', ppl_bpe, num_iters)
        v_writer.add_scalar('ppl_word', ppl_word, num_iters)
        v_writer.add_scalar('kl', kl, num_iters)
        logging.info('val loss    : %.4f' % loss_bpe)
        logging.info('val ppl_bpe : %.4f' % ppl_bpe)
        logging.info('val ppl_word: %.4f' % ppl_word)
        logging.info('val   kl    : %.4f' % kl)

        VAE.train()

    def test_plot(test_loader, num_iters):
        """Save a t-SNE scatter plot of latent means over ``test_loader``.

        Each point is the encoder mean for one conditioning sequence
        (``VAE.encoder_prior`` for 'cvae', ``VAE.encoder`` otherwise), coloured
        by a dataset-specific label parsed from the decoded ``x_tokens``; the
        plot goes to ``save_folder/tSNE_<num_iters>.png``. Everything after the
        embeddings are collected is best-effort: failures are logged and
        training continues.

        Raises:
            ValueError: if ``args.dataset`` has no labelling rule here.
        """
        VAE.eval()

        # get embedding
        emb_chunks = []
        y = []

        with tqdm(total=len(test_loader)) as pbar:
            for (x_mask, x_tokens, _y_mask, _y_tokens, _input_tokens,
                 _target_tokens, _mask) in test_loader:
                x_mask = x_mask.to(device)
                x_tokens = x_tokens.to(device)
                with torch.no_grad():
                    if args.model_type == 'cvae':
                        latent_mean, _ = VAE.encoder_prior(
                            input_ids=x_tokens, attention_mask=x_mask)[:2]
                    else:
                        latent_mean, _ = VAE.encoder(input_ids=x_tokens,
                                                     attention_mask=x_mask)[:2]

                if args.dataset == 'ax' or args.dataset == 'yp':
                    label = [
                        tokenizer.decode(l)[:2] for l in x_tokens.tolist()
                    ]
                elif args.dataset == 'wp':
                    label = []
                    prompts = [
                        tokenizer.decode(l)[:6].lower()
                        for l in x_tokens.tolist()
                    ]
                    for prom in prompts:
                        # Tagged prompts look like "[ xx ]" / "( xx )"; the
                        # length check avoids an IndexError on short prompts.
                        if (len(prom) > 5 and prom[0] in ['[', '(']
                                and prom[5] in [']', ')']):
                            label.append(prom[2:4])
                        else:
                            label.append(None)
                elif args.dataset == 'wi':
                    # 0. TV, play, miniseries, telenovela; 1.film; 2. music; 3. manga, comic, 4. book, novel, story 5. game
                    label = []
                    prompts = [tokenizer.decode(l) for l in x_tokens.tolist()]
                    for prom in prompts:
                        if 'TV' in prom or 'play' in prom or 'miniseries' in prom or 'telenovela' in prom:
                            label.append(0)
                        elif 'film' in prom:
                            label.append(1)
                        elif 'music' in prom:
                            label.append(2)
                        elif 'manga' in prom or 'comic' in prom:
                            label.append(3)
                        elif 'book' in prom or 'novel' in prom or 'story' in prom:
                            label.append(4)
                        elif 'game' in prom:
                            label.append(5)
                        else:
                            label.append(None)
                else:
                    raise ValueError('test_plot: unsupported dataset %r' %
                                     (args.dataset, ))

                emb_chunks.append(latent_mean.data)
                y.extend(label)
                pbar.update(1)

        if not emb_chunks:
            logging.warning('t-SNE plot skipped: empty test loader')
            VAE.train()
            return
        # Concatenate once instead of re-allocating on every batch.
        X_emb = torch.cat(emb_chunks, dim=0).cpu().numpy()

        fig = None
        try:
            if args.dataset == 'yp':
                # Merge ratings 0/1 and 3/4, and drop the neutral rating 2.
                y = ['0' if l in ['0', '1'] else l for l in y]
                y = ['4' if l in ['3', '4'] else l for l in y]
                X_emb = X_emb[[l != '2' for l in y], :]
                y = [l for l in y if l != '2']

            if args.dataset == 'wp':
                # Map each prompt tag to the index of its topic group.
                topics = [['wp', 'sp', 'tt'], ['eu'], ['cw'], ['pm'],
                          ['mp', 'ip'], ['pi', 'cc'], ['ot'], ['rf']]
                match = [[l in t for t in topics] for l in y]
                y = [m.index(True) if True in m else None for m in match]
                X_emb = X_emb[[l is not None for l in y], :]
                y = [l for l in y if l is not None]

            if args.dataset == 'wi':
                X_emb = X_emb[[l is not None for l in y], :]
                y = [l for l in y if l is not None]

            # to 2D
            X_emb_2d = TSNE(n_components=2, verbose=1,
                            perplexity=40).fit_transform(X_emb)

            def remove_outliers(data, r=2.0):
                """Flag rows farther than r std-devs from the mean on any axis."""
                outliers_data = abs(
                    data - np.mean(data, axis=0)) >= r * np.std(data, axis=0)
                outliers = np.any(outliers_data, axis=1)
                keep = np.logical_not(outliers)
                return outliers, keep

            _, keep = remove_outliers(X_emb_2d)
            X_emb_2d = X_emb_2d[keep, :]
            y = [l for l, k in zip(y, keep.tolist()) if k]

            # plot
            fig = plt.figure(figsize=(4, 4))
            ax = fig.add_axes([0, 0, 1, 1])
            cc = ['r', 'b', 'g', 'y', 'k', 'c', 'm', 'tab:blue']
            for ci, l in enumerate(sorted(set(y))):
                idx = [yl == l for yl in y]
                plt.scatter(X_emb_2d[idx, 0],
                            X_emb_2d[idx, 1],
                            c=cc[ci],
                            s=10,
                            edgecolor='none',
                            alpha=0.5)
            ax.axis('off')  # adding it will get no axis
            plt.savefig(
                os.path.join(save_folder,
                             'tSNE_' + '{:07d}'.format(num_iters) + '.png'))
        except Exception as err:
            # Best-effort (e.g. more label classes than the 8 colours in
            # `cc`): record the failure instead of aborting training.
            logging.warning('t-SNE plot failed at iteration %d: %s',
                            num_iters, err)
        finally:
            if fig is not None:
                plt.close(fig)

        VAE.train()

    def generate(test_loader, num_iters):
        """Sample continuations for the first 10 test batches and score them.

        Writes outline / reference story / generated text to
        ``save_folder/generate-<num_iters>.txt`` and logs the mean BLEU-4 and
        ROUGE values over the samples that could be scored. Overwrites
        ``args.nsamples``, ``args.batch_size``, ``args.temperature``,
        ``args.top_k`` and ``args.top_p``.
        """
        VAE.eval()

        n_samples = 0
        bleu4_sum = 0.0
        # rouge-1 / rouge-2 / rouge-l, each with (f, p, r) -> 9 values.
        rouge_scores_values_sum = [0.0] * 9

        args.nsamples = 1
        args.batch_size = 1
        args.temperature = 0.95
        args.top_k = 100
        args.top_p = 0.95
        model_type = args.model_type

        samples_path = os.path.join(
            save_folder, 'generate-' + '%07d' % num_iters + '.txt')
        # `with` guarantees the samples file is flushed and closed.
        with open(samples_path, 'w', encoding='utf8') as samples_file:
            with tqdm(total=len(test_loader)) as pbar:
                for i_test, (x_mask, x_tokens, y_mask, y_tokens, input_tokens,
                             target_tokens, mask) in enumerate(test_loader):

                    if i_test >= 10: break

                    # Generate up to the remaining context window after the
                    # conditioning tokens.
                    length = VAE.config.n_ctx - x_tokens.size(1) - 1

                    eff_samples = []
                    n, _ = target_tokens.size()
                    storys = [
                        tokenizer.decode(target_tokens[i, :]) for i in range(n)
                    ]
                    storys = [
                        s[s.find("<|endoftext|>") + len("<|endoftext|>"):]
                        for s in storys
                    ]
                    storys_str = [
                        s[:s.find("<|endoftext|>") +
                          len("<|endoftext|>")] if "<|endoftext|>" in s else s
                        for s in storys
                    ]

                    for _ in range(args.nsamples // args.batch_size):
                        out, _ = sample_sequence(
                            model=VAE,
                            tokenizer=tokenizer,
                            length=length,
                            batch_size=args.batch_size,
                            x_mask=x_mask,
                            x_tokens=x_tokens,
                            y_mask=y_mask,
                            y_tokens=y_tokens,
                            temperature=args.temperature,
                            top_k=args.top_k,
                            top_p=args.top_p,
                            device=device,
                            eos_token=tokenizer.encoder['<|endoftext|>'],
                            model_type=model_type)
                        out = out.tolist()

                        # extract story, check metrics
                        for i in range(len(out)):
                            text = out[i]
                            text = text[text.index(endoftext) + 1:]

                            if endoftext in text:
                                idx = text.index(endoftext)
                                text = text[:idx]

                            text = tokenizer.decode(text).strip()

                            try:
                                # check bleu
                                bleu4 = sentence_bleu(
                                    [storys_str[i].split()],
                                    text,
                                    smoothing_function=SmoothingFunction().method7)

                                # check rouge
                                rouge = Rouge()
                                rouge_scores = rouge.get_scores(
                                    text, storys_str[i])
                                rouge_scores_values = [
                                    v for k in rouge_scores[0].keys()
                                    for v in rouge_scores[0][k].values()
                                ]

                                bleu4_sum += bleu4
                                rouge_scores_values_sum = [
                                    v1 + v2
                                    for v1, v2 in zip(rouge_scores_values_sum,
                                                      rouge_scores_values)
                                ]
                                n_samples += 1
                            except Exception:
                                # Unscorable sample (e.g. empty text): record
                                # zeros and leave it out of the averages.
                                bleu4 = 0.0
                                rouge_scores = [{
                                    'rouge-1': {
                                        'f': 0.0,
                                        'p': 0.0,
                                        'r': 0.0
                                    },
                                    'rouge-2': {
                                        'f': 0.0,
                                        'p': 0.0,
                                        'r': 0.0
                                    },
                                    'rouge-l': {
                                        'f': 0.0,
                                        'p': 0.0,
                                        'r': 0.0
                                    }
                                }]

                            eff_samples.append((text, bleu4, rouge_scores))

                        pbar.update(1)

                    for i in range(len(eff_samples)):
                        samples_file.write("=" * 50 + " SAMPLE " + str(i_test) +
                                           " " + "=" * 50)
                        samples_file.write('\n' * 2)

                        samples_file.write("=" * 40 + " Outlines  " + "=" * 40)
                        samples_file.write('\n' * 2)
                        samples_file.write(
                            tokenizer.decode(
                                x_tokens[i, :][x_mask[i, :] == 1].tolist()))
                        samples_file.write('\n' * 2)
                        samples_file.write("=" * 40 + " Story " + "=" * 40)
                        samples_file.write('\n' * 2)
                        samples_file.write(storys_str[i])
                        samples_file.write('\n' * 2)

                        samples_file.write("=" * 40 + " Generated " + "=" * 40)
                        samples_file.write('\n' * 2)
                        samples_file.write(eff_samples[i][0])
                        samples_file.write('\n' * 4)
                        samples_file.flush()

        print('Test complete with %05d samples.' % n_samples)
        logging.info("Test complete with %05d samples.", n_samples)
        logging.info("Iteration completed: %d" % num_iters)

        # max(..., 1) avoids ZeroDivisionError when no sample could be scored.
        denom = max(n_samples, 1)
        bleu4 = round(bleu4_sum / denom, 3)
        rouge_scores_values = [
            round(r / denom, 3) for r in rouge_scores_values_sum
        ]
        print(' bleu-4:', bleu4)
        print(' rouge :', rouge_scores_values)
        logging.info(' bleu-4: %f', bleu4)
        logging.info(' rouge : %s', str(rouge_scores_values))

        VAE.train()

    # Baseline evaluation and checkpoint before any training step.
    if args.rank == 0:
        test_plot(test_loader, num_iters)
        val_step(val_loader)
        generate(test_loader, num_iters)
        torch.save(
            VAE.state_dict(),
            os.path.join(save_folder,
                         'model_' + '{:07d}'.format(num_iters) + '.pt'))

    while num_iters < args.iterations:
        # Run epoch
        st = time.time()

        # Training
        print('Training loop. Batches:', len(train_loader))
        logging.info(
            '\n----------------------------------------------------------------------'
        )
        logging.info("Training loop.       Batches: %d" % len(train_loader))

        with tqdm(total=len(train_loader)) as pbar:
            for (x_mask, x_tokens, y_mask, y_tokens, input_tokens,
                 target_tokens, mask) in train_loader:

                # Cyclical linear KL annealing: beta grows by
                # (1 - beta_0) / beta_warmup per step once the position in the
                # cycle reaches cycle - beta_warmup - 25000, so (assuming
                # cycle > beta_warmup + 25000) it reaches 1.0 and is held there
                # for roughly the last 25000 iterations of each cycle.
                if num_iters % args.cycle >= args.cycle - args.beta_warmup - 25000:
                    beta = min(1.0,
                               beta + (1. - args.beta_0) / args.beta_warmup)

                # Unfreeze all (pre-trained) parameters after the warm-up.
                if not tuning_all and num_iters >= tuning_all_after_iters:
                    for parameter in VAE.parameters():
                        parameter.requires_grad = True
                    tuning_all = True

                output = train_step(device, loss_model, optimizer, x_mask,
                                    x_tokens, y_mask, y_tokens, input_tokens,
                                    target_tokens, mask, loss_fn, beta,
                                    args.model_type)

                if args.rank == 0:
                    loss, ce_loss, kl_loss = output[-1]
                    lr = scheduler.get_last_lr()[0]
                    # Log to Tensorboard
                    t_writer.add_scalar('loss', loss, num_iters)
                    t_writer.add_scalar('ppl', math.exp(min(ce_loss, 10)),
                                        num_iters)
                    t_writer.add_scalar('lr', lr, num_iters)
                    t_writer.add_scalar('iter_time',
                                        time.time() - st, num_iters)
                    t_writer.add_scalar('kl', kl_loss, num_iters)
                    t_writer.add_scalar('beta', beta, num_iters)

                    if args.model_type == 'ae_vae_fusion':
                        loss, ce_loss, kl_loss = output[0]
                        # Log to Tensorboard
                        t_writer.add_scalar('ae_loss', loss, num_iters)
                        t_writer.add_scalar('ae_kl', kl_loss, num_iters)

                st = time.time()
                end = num_iters >= args.iterations

                if args.warmup != -1:
                    scheduler.step()

                if end: break
                num_iters += 1
                pbar.update(1)

                if num_iters % args.cycle == 0:
                    beta = args.beta_0
                    logging.info('KL annealing restart')

                # Periodic evaluation followed by a checkpoint (rank 0 only).
                if args.rank == 0 and num_iters % 10000 == 0:
                    test_plot(test_loader, num_iters)
                    val_step(val_loader)
                    generate(test_loader, num_iters)

                    print('Saving model...')
                    logging.info("Iteration completed: %d, remained %d" %
                                 (num_iters, args.iterations - num_iters))
                    logging.info("Saving model...")
                    logging.info(
                        '\n------------------------------------------------------'
                    )
                    torch.save(
                        VAE.state_dict(),
                        os.path.join(
                            save_folder,
                            'model_' + '{:07d}'.format(num_iters) + '.pt'))

                if args.switch_time > 0 and num_iters == int(
                        args.iterations * args.switch_time):
                    print('Switch to long sequence training')
                    logging.info("Switch to long sequence training")
                    cur_b_schedule += 1
                    train_loader, val_loader, test_loader = prepare_dataset(
                        args.data_dir,
                        args.dataset,
                        tokenizer,
                        batch_schedule[cur_b_schedule][0],
                        batch_schedule[cur_b_schedule][1],
                        batch_schedule[-1][0],
                        batch_schedule[-1][1],
                        batch_schedule[-1][0],
                        batch_schedule[-1][1],
                        make_test=True,
                        num_workers=args.workers,
                        data_type=args.data_type)
                    # Keep validating on the test split, as configured above.
                    val_loader = test_loader
        if not end:
            e += 1
            logging.info("Training loop. The ith epoch completed: %d" % e)

    if args.rank == 0:
        torch.save(VAE.state_dict(),
                   os.path.join(save_folder, 'model_latest.pt'))
    print('Training complete.')
    logging.info("Training complete.")
示例#7
0
# Remaining CLI options, then parse the command line into `args`.
parser.add_argument('--show', '-show', type=str, default='predict',
                    help='choose what to do from [predict, inter]')
parser.add_argument('--method', '-method', type=str, default='restoration',
                    help='choose what to do from [restoration, depth, con2rgb]')
parser.add_argument('--grayscale', '-grayscale', action='store_true',
                    dest='grayscale',
                    help='choose for training on grayscale images')
args = parser.parse_args()


# Forward pass wrapped in tf.function so TensorFlow runs it as a traced graph.
@tf.function
def predict_image(image, inter=None):
    """Run the module-level ``model`` on ``image``, passing ``inter`` through."""
    prediction = model(image, inter)
    return prediction

# Let TensorFlow grow GPU memory on demand instead of reserving it all up
# front; in TF 2.x this can be switched on through an environment variable.
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

# Load the model and its trained weights.
model = VAEModel(n_z=args.n_z)
model.load_weights(args.path)

# Load the image and turn it into a (1, H, W, 3) float32 batch in [0, 1].
im = Image.open(args.img)
# Normalize to 3-channel RGB. Grayscale ('L') input becomes three identical
# channels, matching the previous expand+repeat handling. RGBA, palette ('P')
# and 1-bit images are also handled now; previously they yielded 4 channels,
# raw palette indices, or booleans divided by 255.
if im.mode != 'RGB':
    im = im.convert('RGB')
if args.res > 0:
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    im = im.resize((args.res, args.res), Image.LANCZOS)
input_img = (np.expand_dims(np.array(im), axis=0) / 255.0).astype(np.float32)

if args.show == 'inter': # interpolate over the latent space

    latent_array = np.array([-0.02, -0.01, 0.0, 0.01, 0.02])
示例#8
0
    if y_test[i] == 9:
        temp = x_test[i, :, :, :]
        x_test_9.append(
            temp.reshape(
                (x_train_shape[1], x_train_shape[2], x_train_shape[3])))

# Turn the per-class test sample lists into arrays.
x_test_1, x_test_9 = np.array(x_test_1), np.array(x_test_9)

# Network parameters.
input_shape = (8, 8, 1)
batch_size = 128
epochs = 5

# Build the VAE wrapper and obtain the trainable model from it.
model = VAEModel(input_shape)
vae = model.load_model(input_shape)

# Train the autoencoder on x_train_1 only (no targets are passed), then save
# it without optimizer state.
vae.fit(x_train_1, batch_size=batch_size, epochs=epochs)
vae.save('vae_mlp_mnist.h5', include_optimizer=False)

# Normal / anomalous test data: draw one random sample from each set.
idx1 = np.random.randint(len(x_test_1))
idx2 = np.random.randint(len(x_test_9))

test_normal = x_test_1[idx1, :, :, :]
test_anomaly = x_test_9[idx2, :, :, :]

test_normal = test_normal.reshape(1, test_normal.shape[0],
示例#9
0
    test_kl_loss(kl_loss)


# Allow GPU memory growth via an env var (supported in TF 2.x) and restrict
# the visible GPUs to those given by --gpu.
# NOTE(review): both only take effect if set before TensorFlow first
# initializes its GPU devices -- confirm nothing above touches the GPU.
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

# Get train and test datasets.
train_ds, test_ds = import_data.load_dataset_for_frame_restoration(
    args.data_file,
    args.num_imgs,
    args.batch_size,
    random_sample=args.random_sample)

# Create model and optimizer.
model = VAEModel(n_z=args.n_z, out_channels=3)
optimizer = tf.keras.optimizers.Adam()

# Ensure the output folder exists before the summary writer below targets it.
# exist_ok=True replaces the racy isdir()-then-makedirs() check.
os.makedirs(args.output_dir, exist_ok=True)

# Running means of the reconstruction and KL terms for train and test, plus a
# summary file writer rooted at output_dir.
train_rec_loss = tf.keras.metrics.Mean(name='train_rec_loss')
train_kl_loss = tf.keras.metrics.Mean(name='train_kl_loss')
test_rec_loss = tf.keras.metrics.Mean(name='test_rec_loss')
test_kl_loss = tf.keras.metrics.Mean(name='test_kl_loss')
metrics_writer = tf.summary.create_file_writer(args.output_dir)

# train
print('Start training...')