Code example #1
# Imports used by this snippet (project-local modules such as Vocab, BIGLM,
# DataLoader, Optim and average_gradients come from the repository itself and
# are assumed to be importable).
import os
import random

import torch
import torch.distributed as dist


def run(args, local_rank):
    """ Distributed Synchronous """
    torch.manual_seed(1234)
    vocab = Vocab(args.vocab, min_occur_cnt=args.min_occur_cnt, specials=[])
    if (args.world_size == 1 or dist.get_rank() == 0):
        print ("vocab.size = %d"%vocab.size, flush=True)
    model = BIGLM(local_rank, vocab, args.embed_dim, args.ff_embed_dim,\
                  args.num_heads, args.dropout, args.layers, args.smoothing, args.approx)
    if args.start_from is not None:
        ckpt = torch.load(args.start_from, map_location='cpu')
        model.load_state_dict(ckpt['model'])
    model = model.cuda(local_rank)
   
    if args.world_size > 1:
        torch.manual_seed(1234 + dist.get_rank())
        random.seed(5678 + dist.get_rank())
    
    optimizer = Optim(
        model.embed_dim, args.lr, args.warmup_steps,
        torch.optim.Adam(model.parameters(),
                         lr=0,
                         betas=(0.9, 0.998),
                         eps=1e-9))

    if args.start_from is not None:
        optimizer.load_state_dict(ckpt['optimizer'])

    #train_data = DataLoader(vocab, args.train_data+"0"+str(local_rank), args.batch_size, args.max_len, args.min_len)
    train_data = DataLoader(vocab, args.train_data, args.batch_size, args.max_len, args.min_len)
    batch_acm = 0
    acc_acm, nll_acm, ppl_acm, ntokens_acm, nxs, npairs_acm, loss_acm = 0., 0., 0., 0., 0., 0., 0.
    while True:
        model.train()
        for truth, inp, msk in train_data:
            batch_acm += 1
            truth = truth.cuda(local_rank)
            inp = inp.cuda(local_rank)
            msk = msk.cuda(local_rank)

            model.zero_grad()
            res, loss, acc, nll, ppl, ntokens, npairs = model(truth, inp, msk)
            loss_acm += loss.item()
            acc_acm += acc
            nll_acm += nll
            ppl_acm += ppl
            ntokens_acm += ntokens
            npairs_acm += npairs
            nxs += npairs
            
            loss.backward()
            if args.world_size > 1:
                average_gradients(model)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            if (args.world_size == 1 or dist.get_rank() == 0
                ) and batch_acm % args.print_every == -1 % args.print_every:
                print(
                    'batch_acm %d, loss %.3f, acc %.3f, nll %.3f, ppl %.3f, x_acm %d, lr %.6f'
                    % (batch_acm, loss_acm / args.print_every,
                       acc_acm / ntokens_acm, nll_acm / nxs, ppl_acm / nxs,
                       npairs_acm, optimizer._rate),
                    flush=True)
                acc_acm, nll_acm, ppl_acm, ntokens_acm, loss_acm, nxs = 0., 0., 0., 0., 0., 0.
            if (args.world_size == 1 or dist.get_rank() == 0
                ) and batch_acm % args.save_every == -1 % args.save_every:
                if not os.path.exists(args.save_dir):
                    os.mkdir(args.save_dir)
                torch.save(
                    {
                        'args': args,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }, '%s/epoch%d_batch_%d' %
                    (args.save_dir, train_data.epoch_id, batch_acm))
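The average_gradients helper called above is not part of this snippet. Below is a minimal sketch of the pattern such a helper usually follows in PyTorch distributed training (an all-reduce of every parameter's gradient followed by division by the world size). This is an assumption about its behaviour, not the repository's exact implementation; note that the variants used in code examples #5 and #6 additionally return a flag indicating whether the reduced gradients are usable.

import torch.distributed as dist


def average_gradients(model):
    """Average gradients across all workers (sketch of the usual all-reduce pattern)."""
    world_size = float(dist.get_world_size())
    for param in model.parameters():
        if param.grad is not None:
            # sum the gradient over all ranks, then divide to get the mean
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= world_size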
Code example #2
start_time = time.time()
for epoch in range(1, args.max_iter + 1):
    trn_loader = Data.get_batches(Data.trn_set,
                                  batch_size=args.batch_size,
                                  shuffle=True)
    for bidx in range(n_batchs):
        model.train()
        inputs = next(trn_loader)
        X_p, X_f, Y_true = inputs[0], inputs[1], inputs[2]

        model.zero_grad()
        X_p_dec = model(X_p)
        loss = torch.mean((X_p - X_p_dec)**2)

        loss.backward()
        optimizer.step()
        update += 1
        if update % args.eval_freq == 0:
            # ========= Main block for evaluate MMD(X_p_enc, X_f_enc) on RNN codespace  =========#
            val_dict = valid_epoch(Data, Data.val_set, model, args.batch_size,
                                   Y_val, L_val, args.model)
            tst_dict = valid_epoch(Data, Data.tst_set, model, args.batch_size,
                                   Y_tst, L_tst, args.model)
            total_time = time.time() - start_time
            print(
                'iter %4d tm %4.2fm trn_loss %.4e val_mse %.4f val_mae %.4f val_auc %.6f'
                % (epoch, total_time / 60.0, loss.item(), val_dict['mse'],
                   val_dict['mae'], val_dict['auc']),
                end='')

            print(" tst_mse %.4f tst_mae %.4f tst_auc %.6f" %
Code example #3
def main(hparams: HParams):
    '''
    Set up training.
    '''
    if torch.cuda.is_available() and not hparams.gpus:
        warnings.warn(
            'WARNING: you have a CUDA device, so you should probably run with -gpus 0'
        )

    device = torch.device(hparams.gpus if torch.cuda.is_available() else 'cpu')

    # data setup
    print(f"Loading vocabulary...")
    text_preprocessor = TextPreprocessor.load(hparams.preprocessor_path)

    transform = transforms.Compose([
        transforms.Resize([hparams.img_size, hparams.img_size]),
        transforms.RandomCrop([hparams.crop_size, hparams.crop_size]),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # create dataloader
    print('Creating DataLoader...')
    normal_data_loader = get_image_caption_loader(
        hparams.img_dir,
        hparams.normal_caption_path,
        text_preprocessor,
        hparams.normal_batch_size,
        transform,
        shuffle=True,
        num_workers=hparams.num_workers,
    )

    style_data_loader = get_caption_loader(
        hparams.style_caption_path,
        text_preprocessor,
        batch_size=hparams.style_batch_size,
        shuffle=True,
        num_workers=hparams.num_workers,
    )

    if hparams.train_from:
        # loading checkpoint
        print('Loading checkpoint...')
        checkpoint = torch.load(hparams.train_from)

    # The optimizers are needed in both branches (their states are restored from
    # the checkpoint further below when train_from is given), so build them
    # unconditionally instead of only in the else branch.
    normal_opt = Optim(
        hparams.optimizer,
        hparams.normal_lr,
        hparams.max_grad_norm,
        hparams.lr_decay,
        hparams.start_decay_at,
    )
    style_opt = Optim(
        hparams.optimizer,
        hparams.style_lr,
        hparams.max_grad_norm,
        hparams.lr_decay,
        hparams.start_decay_at,
    )

    print('Building model...')
    encoder = EncoderCNN(hparams.hidden_dim)
    decoder = FactoredLSTM(hparams.embed_dim,
                           text_preprocessor.vocab_size,
                           hparams.hidden_dim,
                           hparams.style_dim,
                           hparams.num_layers,
                           hparams.random_init,
                           hparams.dropout_ratio,
                           train=True,
                           device=device)

    encoder = encoder.to(device)
    decoder = decoder.to(device)

    # loss and optimizer
    criterion = nn.CrossEntropyLoss(ignore_index=text_preprocessor.PAD_ID)
    normal_params = list(encoder.parameters()) + list(
        decoder.default_parameters())
    style_params = list(decoder.style_parameters())
    normal_opt.set_parameters(normal_params)
    style_opt.set_parameters(style_params)

    if hparams.train_from:
        encoder.load_state_dict(checkpoint['encoder'])
        decoder.load_state_dict(checkpoint['decoder'])
        normal_opt.load_state_dict(checkpoint['normal_opt'])
        style_opt.load_state_dict(checkpoint['style_opt'])

    # training loop
    print('Start training...')
    for epoch in range(hparams.num_epoch):

        # result
        sum_normal_loss, sum_style_loss, sum_normal_ppl, sum_style_ppl = 0, 0, 0, 0

        # normal caption
        for i, (images, in_captions, out_captions,
                lengths) in enumerate(normal_data_loader):
            images = images.to(device)
            in_captions = in_captions.to(device)
            out_captions = out_captions.contiguous().view(-1).to(device)

            # Forward, backward and optimize
            features = encoder(images)
            outputs = decoder(in_captions, features, mode='default')
            loss = criterion(outputs.view(-1, outputs.size(-1)), out_captions)
            encoder.zero_grad()
            decoder.zero_grad()
            loss.backward()
            normal_opt.step()

            # print log
            sum_normal_loss += loss.item()
            sum_normal_ppl += np.exp(loss.item())
            if i % hparams.normal_log_step == 0:
                print(
                    f'Epoch [{epoch}/{hparams.num_epoch}], Normal Step: [{i}/{len(normal_data_loader)}] '
                    f'Normal Loss: {loss.item():.4f}, Perplexity: {np.exp(loss.item()):5.4f}'
                )

        # style caption
        for i, (in_captions, out_captions,
                lengths) in enumerate(style_data_loader):
            in_captions = in_captions.to(device)
            out_captions = out_captions.contiguous().view(-1).to(device)

            # Forward, backward and optimize
            outputs = decoder(in_captions, None, mode='style')
            loss = criterion(outputs.view(-1, outputs.size(-1)), out_captions)

            decoder.zero_grad()
            loss.backward()
            style_opt.step()

            sum_style_loss += loss.item()
            sum_style_ppl += np.exp(loss.item())
            # print log
            if i % hparams.style_log_step == 0:
                print(
                    f'Epoch [{epoch}/{hparams.num_epoch}], Style Step: [{i}/{len(style_data_loader)}] '
                    f'Style Loss: {loss.item():.4f}, Perplexity: {np.exp(loss.item()):5.4f}'
                )

        model_params = {
            'encoder': encoder.state_dict(),
            'decoder': decoder.state_dict(),
            'epoch': epoch,
            'normal_opt': normal_opt.optimizer.state_dict(),
            'style_opt': style_opt.optimizer.state_dict(),
        }

        avg_normal_loss = sum_normal_loss / len(normal_data_loader)
        avg_style_loss = sum_style_loss / len(style_data_loader)
        avg_normal_ppl = sum_normal_ppl / len(normal_data_loader)
        avg_style_ppl = sum_style_ppl / len(style_data_loader)
        print(f'Epoch [{epoch}/{hparams.num_epoch}] statistics')
        print(
            f'Normal Loss: {avg_normal_loss:.4f} Normal ppl: {avg_normal_ppl:5.4f} '
            f'Style Loss: {avg_style_loss:.4f} Style ppl: {avg_style_ppl:5.4f}'
        )

        torch.save(
            model_params,
            f'{hparams.model_path}/n-loss_{avg_normal_loss:.4f}_s-loss_{avg_style_loss:.4f}_'
            f'n-ppl_{avg_normal_ppl:5.4f}_s-ppl_{avg_style_ppl:5.4f}_epoch_{epoch}.pt'
        )
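The Optim class used in this example is not shown. Its constructor arguments (optimizer name, learning rate, max_grad_norm, lr_decay, start_decay_at) and its set_parameters / step / optimizer / load_state_dict members match the thin wrapper popularized by OpenNMT-py. The sketch below is an assumption about that interface, not the project's actual code, and it omits the learning-rate decay logic.

import torch


class Optim:
    """Thin optimizer wrapper: builds the inner optimizer and clips gradients on step (sketch)."""

    def __init__(self, method, lr, max_grad_norm, lr_decay=1.0, start_decay_at=None):
        self.method = method
        self.lr = lr
        self.max_grad_norm = max_grad_norm
        self.lr_decay = lr_decay
        self.start_decay_at = start_decay_at

    def set_parameters(self, params):
        self.params = [p for p in params if p.requires_grad]
        if self.method == 'sgd':
            self.optimizer = torch.optim.SGD(self.params, lr=self.lr)
        elif self.method == 'adam':
            self.optimizer = torch.optim.Adam(self.params, lr=self.lr)
        else:
            raise ValueError('unsupported optimizer: %s' % self.method)

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)

    def step(self):
        # clip the global gradient norm before the parameter update
        if self.max_grad_norm:
            torch.nn.utils.clip_grad_norm_(self.params, self.max_grad_norm)
        self.optimizer.step()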
Code example #4
File: klcpd.py  Project: siamakz/klcpd_code
            mmd2_real = mmd_util.batch_mmd2_loss(X_p_enc, X_f_enc, sigma_var)

            # reconstruction loss
            real_L2_loss = torch.mean((X_f - X_f_dec)**2)
            #real_L2_loss = torch.mean((X_p - X_p_dec)**2)
            fake_L2_loss = torch.mean((Y_f - Y_f_dec)**2)
            #fake_L2_loss = torch.mean((Y_f - Y_f_dec)**2) * 0.0

            # update netD
            netD.zero_grad()
            lossD = D_mmd2.mean() - lambda_ae * (
                real_L2_loss + fake_L2_loss) - lambda_real * mmd2_real.mean()
            #lossD = 0.0 * D_mmd2.mean() - lambda_ae * (real_L2_loss + fake_L2_loss) - lambda_real * mmd2_real.mean()
            #lossD = -real_L2_loss
            lossD.backward(mone)
            optimizerD.step()

        ############################
        # (2) Update G network
        ############################
        for p in netD.parameters():
            p.requires_grad = False  # to avoid computation

        if bidx == n_batchs:
            break

        inputs = next(trn_loader)
        X_p, X_f = inputs[0], inputs[1]
        batch_size = X_p.size(0)
        bidx += 1
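The mone tensor passed to lossD.backward(mone) is defined earlier in the full script. In WGAN-style code such as this it is conventionally a scalar tensor holding -1, so backward accumulates the gradient of -lossD and the subsequent optimizerD.step() effectively ascends lossD. A tiny self-contained illustration of that convention (the value of mone is an assumption, since its definition lies outside this excerpt):

import torch

w = torch.tensor(2.0, requires_grad=True)
lossD = w * 3.0                 # toy "discriminator loss"

one = torch.tensor(1.0)
mone = -one                     # scalar -1, the usual WGAN convention

lossD.backward(mone)            # accumulates d(-lossD)/dw instead of d(lossD)/dw
print(w.grad)                   # tensor(-3.), so a gradient-descent step increases lossD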
Code example #5
def run(args, local_rank):
    """ Distributed Synchronous """
    torch.manual_seed(1234)
    vocab = Vocab(args.vocab, min_occur_cnt=args.min_occur_cnt, specials=[])
    if (args.world_size == 1 or dist.get_rank() == 0):
        print("vocab.size = " + str(vocab.size), flush=True)
    model = BIGLM(local_rank, vocab, args.embed_dim, args.ff_embed_dim,
                  args.num_heads, args.dropout, args.layers, args.smoothing)
    if args.start_from is not None:
        ckpt = torch.load(args.start_from, map_location='cpu')
        model.load_state_dict(ckpt['model'])
    model = model.cuda(local_rank)

    optimizer = Optim(
        model.embed_dim, args.lr, args.warmup_steps,
        torch.optim.Adam(model.parameters(),
                         lr=0,
                         betas=(0.9, 0.998),
                         eps=1e-9))

    if args.start_from is not None:
        optimizer.load_state_dict(ckpt['optimizer'])

    train_data = DataLoader(vocab, args.train_data, args.batch_size,
                            args.max_len, args.min_len)
    batch_acm = 0
    acc_acm, nll_acm, ppl_acm, ntokens_acm, nxs, npairs_acm, loss_acm = 0., 0., 0., 0., 0., 0., 0.
    while True:
        model.train()
        if train_data.epoch_id > args.max_epoch:
            break
        for xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk in train_data:
            batch_acm += 1
            xs_tpl = xs_tpl.cuda(local_rank)
            xs_seg = xs_seg.cuda(local_rank)
            xs_pos = xs_pos.cuda(local_rank)
            ys_truth = ys_truth.cuda(local_rank)
            ys_inp = ys_inp.cuda(local_rank)
            ys_tpl = ys_tpl.cuda(local_rank)
            ys_seg = ys_seg.cuda(local_rank)
            ys_pos = ys_pos.cuda(local_rank)
            msk = msk.cuda(local_rank)

            model.zero_grad()
            res, loss, acc, nll, ppl, ntokens, npairs = model(
                xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg,
                ys_pos, msk)

            # Reference: http://www.myzaker.com/article/5f3747a28e9f096c723a65e0/
            # Besides the common text-generation metrics (PPL, Distinct),
            # the paper also designs metrics for format accuracy, rhyme accuracy and sentence integrity.
            # Format accuracy: precision p, recall r and F1 scores, reported as Macro-F1 and Micro-F1.
            # Integrity involves a somewhat odd log-based value.
            # Traditional BLEU and ROUGE are not used at all in SongNet, since creative generation calls for diversity.
            loss_acm += loss.item()  # loss
            acc_acm += acc  # accuracy
            nll_acm += nll  # negative log-likelihood
            ppl_acm += ppl  # sum of -log probs, i.e. how likely the sentence is; the smaller it is, the higher the perplexity
            # perplexity compares models on how well they predict held-out samples; lower is better(?), depending on how it is defined
            ntokens_acm += ntokens  # token count
            npairs_acm += npairs  # sentences?
            nxs += npairs

            # why does this feel so hard... gpt2

            loss.backward()
            if args.world_size > 1:
                is_normal = average_gradients(model)
            else:
                is_normal = True
            if is_normal:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
            else:
                print("gradient: none, gpu: " + str(local_rank), flush=True)
                continue
            if (args.world_size == 1 or dist.get_rank() == 0
                ) and batch_acm % args.print_every == -1 % args.print_every:
                today = datetime.datetime.now()
                print(today)
                print(
                    'batch_acm %d, loss %.3f, acc %.3f, nll %.3f, ppl %.3f, x_acm %d, lr %.6f'
                    % (batch_acm, loss_acm / args.print_every,
                       acc_acm / ntokens_acm, nll_acm / nxs, ppl_acm / nxs,
                       npairs_acm, optimizer._rate),
                    flush=True)
                acc_acm, nll_acm, ppl_acm, ntokens_acm, loss_acm, nxs = 0., 0., 0., 0., 0., 0.
            if (args.world_size == 1 or dist.get_rank() == 0
                ) and batch_acm % args.save_every == -1 % args.save_every:
                if not os.path.exists(args.save_dir):
                    os.mkdir(args.save_dir)

                model.eval()
                eval_epoch(
                    args, model, vocab, local_rank, "epoch-" +
                    str(train_data.epoch_id) + "-acm-" + str(batch_acm))
                model.train()

                torch.save(
                    {
                        'args': args,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }, '%s/epoch%d_batch_%d' %
                    (args.save_dir, train_data.epoch_id, batch_acm))
Code example #6
File: train.py  Project: xinyu12138/SongNet
def run(args, local_rank):
    """ Distributed Synchronous """
    torch.manual_seed(1234)
    vocab = Vocab(args.vocab, min_occur_cnt=args.min_occur_cnt, specials=[])
    if (args.world_size == 1 or dist.get_rank() == 0):
        print("vocab.size = " + str(vocab.size), flush=True)
    model = BIGLM(local_rank, vocab, args.embed_dim, args.ff_embed_dim,\
                  args.num_heads, args.dropout, args.layers, args.smoothing)
    if args.start_from is not None:
        ckpt = torch.load(args.start_from, map_location='cpu')
        model.load_state_dict(ckpt['model'])
    model = model.cuda(local_rank)

    optimizer = Optim(
        model.embed_dim, args.lr, args.warmup_steps,
        torch.optim.Adam(model.parameters(),
                         lr=0,
                         betas=(0.9, 0.998),
                         eps=1e-9))

    if args.start_from is not None:
        optimizer.load_state_dict(ckpt['optimizer'])

    train_data = DataLoader(vocab, args.train_data, args.batch_size,
                            args.max_len, args.min_len)
    batch_acm = 0
    acc_acm, nll_acm, ppl_acm, ntokens_acm, nxs, npairs_acm, loss_acm = 0., 0., 0., 0., 0., 0., 0.
    while True:
        model.train()
        if train_data.epoch_id > 30:
            break
        for xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk in train_data:
            batch_acm += 1
            xs_tpl = xs_tpl.cuda(local_rank)
            xs_seg = xs_seg.cuda(local_rank)
            xs_pos = xs_pos.cuda(local_rank)
            ys_truth = ys_truth.cuda(local_rank)
            ys_inp = ys_inp.cuda(local_rank)
            ys_tpl = ys_tpl.cuda(local_rank)
            ys_seg = ys_seg.cuda(local_rank)
            ys_pos = ys_pos.cuda(local_rank)
            msk = msk.cuda(local_rank)

            model.zero_grad()
            res, loss, acc, nll, ppl, ntokens, npairs = model(
                xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg,
                ys_pos, msk)
            loss_acm += loss.item()
            acc_acm += acc
            nll_acm += nll
            ppl_acm += ppl
            ntokens_acm += ntokens
            npairs_acm += npairs
            nxs += npairs

            loss.backward()
            if args.world_size > 1:
                is_normal = average_gradients(model)
            else:
                is_normal = True
            if is_normal:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
            else:
                print("gradient: none, gpu: " + str(local_rank), flush=True)
                continue
            if (args.world_size == 1 or dist.get_rank() == 0
                ) and batch_acm % args.print_every == -1 % args.print_every:
                print(
                    'batch_acm %d, loss %.3f, acc %.3f, nll %.3f, ppl %.3f, x_acm %d, lr %.6f'
                    % (batch_acm, loss_acm / args.print_every,
                       acc_acm / ntokens_acm, nll_acm / nxs, ppl_acm / nxs,
                       npairs_acm, optimizer._rate),
                    flush=True)
                acc_acm, nll_acm, ppl_acm, ntokens_acm, loss_acm, nxs = 0., 0., 0., 0., 0., 0.
            if (args.world_size == 1 or dist.get_rank() == 0
                ) and batch_acm % args.save_every == -1 % args.save_every:
                if not os.path.exists(args.save_dir):
                    os.mkdir(args.save_dir)

                model.eval()
                eval_epoch(
                    args, model, vocab, local_rank, "epoch-" +
                    str(train_data.epoch_id) + "-acm-" + str(batch_acm))
                model.train()

                torch.save(
                    {
                        'args': args,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }, '%s/epoch%d_batch_%d' %
                    (args.save_dir, train_data.epoch_id, batch_acm))
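Code examples #1, #5 and #6 all build Optim(model.embed_dim, args.lr, args.warmup_steps, torch.optim.Adam(...)) and later read optimizer._rate in the log line, which matches the Noam warmup schedule from "Attention Is All You Need" (inverse square-root decay after a linear warmup, as in The Annotated Transformer). A minimal sketch of such a wrapper follows, stated as an assumption rather than taken from the SongNet repository.

class Optim:
    """Noam-style learning-rate wrapper around an inner optimizer (sketch)."""

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self.model_size = model_size
        self.factor = factor
        self.warmup = warmup
        self._step = 0
        self._rate = 0.

    def rate(self, step=None):
        # lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
        if step is None:
            step = self._step
        return self.factor * (self.model_size ** -0.5 *
                              min(step ** -0.5, step * self.warmup ** -1.5))

    def step(self):
        self._step += 1
        self._rate = self.rate()
        for group in self.optimizer.param_groups:
            group['lr'] = self._rate
        self.optimizer.step()

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)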