Example #1
def run(args, local_rank):
    """ Distributed Synchronous """
    torch.manual_seed(1234)
    vocab = Vocab(args.vocab, min_occur_cnt=args.min_occur_cnt, specials=[])
    if (args.world_size == 1 or dist.get_rank() == 0):
        print ("vocab.size = %d"%vocab.size, flush=True)
    model = BIGLM(local_rank, vocab, args.embed_dim, args.ff_embed_dim,\
                  args.num_heads, args.dropout, args.layers, args.smoothing, args.approx)
    if args.start_from is not None:
        ckpt = torch.load(args.start_from, map_location='cpu')
        model.load_state_dict(ckpt['model'])
    model = model.cuda(local_rank)
   
    if args.world_size > 1:
        torch.manual_seed(1234 + dist.get_rank())
        random.seed(5678 + dist.get_rank())
    
    optimizer = Optim(model.embed_dim, args.lr, args.warmup_steps, torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.998), eps=1e-9))

    if args.start_from is not None:
        optimizer.load_state_dict(ckpt['optimizer'])

    #train_data = DataLoader(vocab, args.train_data+"0"+str(local_rank), args.batch_size, args.max_len, args.min_len)
    train_data = DataLoader(vocab, args.train_data, args.batch_size, args.max_len, args.min_len)
    batch_acm = 0
    acc_acm, nll_acm, ppl_acm, ntokens_acm, nxs, npairs_acm, loss_acm = 0., 0., 0., 0., 0., 0., 0.
    while True:
        model.train()
        for truth, inp, msk in train_data:
            batch_acm += 1
            truth = truth.cuda(local_rank)
            inp = inp.cuda(local_rank)
            msk = msk.cuda(local_rank)

            model.zero_grad()
            res, loss, acc, nll, ppl, ntokens, npairs = model(truth, inp, msk)
            loss_acm += loss.item()
            acc_acm += acc
            nll_acm += nll
            ppl_acm += ppl
            ntokens_acm += ntokens
            npairs_acm += npairs
            nxs += npairs
            
            loss.backward()
            if args.world_size > 1:
                average_gradients(model)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            if (args.world_size==1 or dist.get_rank() ==0) and batch_acm%args.print_every == -1%args.print_every:
                print ('batch_acm %d, loss %.3f, acc %.3f, nll %.3f, ppl %.3f, x_acm %d, lr %.6f'\
                        %(batch_acm, loss_acm/args.print_every, acc_acm/ntokens_acm, \
                        nll_acm/nxs, ppl_acm/nxs, npairs_acm, optimizer._rate), flush=True)
                acc_acm, nll_acm, ppl_acm, ntokens_acm, loss_acm, nxs = 0., 0., 0., 0., 0., 0.
            if (args.world_size==1 or dist.get_rank() ==0) and batch_acm%args.save_every == -1%args.save_every:
                if not os.path.exists(args.save_dir):
                    os.mkdir(args.save_dir)
                torch.save({'args':args, 'model':model.state_dict(), 'optimizer':optimizer.state_dict()}, '%s/epoch%d_batch_%d'%(args.save_dir, train_data.epoch_id, batch_acm))
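# The Optim(model.embed_dim, args.lr, args.warmup_steps, torch.optim.Adam(...))
# wrapper used above is not part of this snippet. Below is a minimal, hypothetical
# sketch of a Noam-style warmup wrapper that matches the calls made here
# (step(), state_dict()/load_state_dict(), and the _rate field read when logging)
# and the keyword names model_size/factor/warmup/optimizer used in Example #6;
# the project's real optim.py may differ.
import torch


class NoamOptimSketch:
    """lr = factor * model_size^-0.5 * min(step^-0.5, step * warmup^-1.5)."""

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self.model_size = model_size
        self.factor = factor
        self.warmup = warmup
        self._step = 0
        self._rate = 0.0

    def rate(self, step=None):
        step = self._step if step is None else step
        return self.factor * (self.model_size ** -0.5
                              * min(step ** -0.5, step * self.warmup ** -1.5))

    def step(self):
        # Write the scheduled learning rate into every param group, then step.
        self._step += 1
        self._rate = self.rate()
        for group in self.optimizer.param_groups:
            group['lr'] = self._rate
        self.optimizer.step()

    def state_dict(self):
        return {'step': self._step, 'optimizer': self.optimizer.state_dict()}

    def load_state_dict(self, state):
        self._step = state['step']
        self.optimizer.load_state_dict(state['optimizer'])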
Example #2
                               train=stage == "train",
                               download=True,
                               transform=transforms.ToTensor())
        loader[stage] = torch.utils.data.DataLoader(mnist,
                                                    batch_size=batch_size,
                                                    shuffle=stage == "train",
                                                    **more)

with scope("Models Construction"):
    G = Generator()
    D = Discriminator()

with scope("Optimizer Builder"):
    hyper = dict(lr=lambda i: 0.1 / (1.00004)**i,
                 momentum=lambda i: 0.5 + min(0.2, i / 1e6))
    Gopt_generator = Optim(torch.optim.SGD, G.parameters(), **hyper)
    Dopt_generator = Optim(torch.optim.SGD, D.parameters(), **hyper)

with scope("Optimization Step"):

    def run(iteration, model, entry, is_training=True):
        if is_training: eval(model).train()
        else: eval(model).eval()

        x, y = entry
        n_batch = x.size(0)
        x_flat = x.flatten(1)
        y_hot = torch.zeros(n_batch, 10)
        y_hot[torch.arange(n_batch), y] = 1  # one-hot: index (row, label); y_hot[y] = 1 would fill whole rows
        z = torch.rand(n_batch, 100)
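# The Optim(torch.optim.SGD, G.parameters(), **hyper) factory built above is not
# shown in this snippet. One plausible, purely illustrative reading is that the
# lr/momentum lambdas are evaluated each iteration and written into the
# optimizer's param groups, e.g.:
import torch


def apply_schedules(optimizer, iteration, schedules):
    """Write each schedule's value at `iteration` into every param group."""
    for group in optimizer.param_groups:
        for name, fn in schedules.items():
            group[name] = fn(iteration)


# Usage with the same schedules as `hyper` above (illustrative only):
net = torch.nn.Linear(4, 2)
schedules = dict(lr=lambda i: 0.1 / (1.00004) ** i,
                 momentum=lambda i: 0.5 + min(0.2, i / 1e6))
opt = torch.optim.SGD(net.parameters(),
                      lr=schedules["lr"](0),
                      momentum=schedules["momentum"](0))
for i in range(3):
    apply_schedules(opt, i, schedules)
    opt.zero_grad()
    net(torch.randn(8, 4)).pow(2).mean().backward()
    opt.step()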
Example #3
def train():
    text = Text(config.src_corpus, config.tar_corpus)
    train_data = Data(config.train_path_src, config.train_path_tar)
    dev_data = Data(config.dev_path_src, config.dev_path_tar)
    train_loader = DataLoader(dataset=train_data,
                              batch_size=config.batch_size,
                              shuffle=True,
                              collate_fn=utils.get_batch)
    dev_loader = DataLoader(dataset=dev_data,
                            batch_size=config.dev_batch_size,
                            shuffle=True,
                            collate_fn=utils.get_batch)
    parser = OptionParser()
    parser.add_option("--embed_size",
                      dest="embed_size",
                      default=config.embed_size)
    parser.add_option("--hidden_size",
                      dest="hidden_size",
                      default=config.hidden_size)
    parser.add_option("--window_size_d",
                      dest="window_size_d",
                      default=config.window_size_d)
    parser.add_option("--encoder_layer",
                      dest="encoder_layer",
                      default=config.encoder_layer)
    parser.add_option("--decoder_layers",
                      dest="decoder_layers",
                      default=config.decoder_layers)
    parser.add_option("--dropout_rate",
                      dest="dropout_rate",
                      default=config.dropout_rate)
    (options, args) = parser.parse_args()
    device = torch.device("cuda:0" if config.cuda else "cpu")
    #model_path = "/home/wangshuhe/shuhelearn/ShuHeLearning/NMT_attention/result/01.31_drop0.3_54_21.46508598886769_checkpoint.pth"
    #print(f"load model from {model_path}", file=sys.stderr)
    #model = NMT.load(model_path)
    model = NMT(text, options, device)
    #model = model.cuda()
    #model_path = "/home/wangshuhe/shuhelearn/ShuHeLearning/NMT_attention/result/140_164.29781984744628_checkpoint.pth"
    #print(f"load model from {model_path}", file=sys.stderr)
    #model = NMT.load(model_path)
    #model = torch.nn.DataParallel(model)
    model = model.to(device)
    model = model.cuda()
    model.train()
    optimizer = Optim(torch.optim.Adam(model.parameters()))
    #optimizer = Optim(torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9), config.hidden_size, config.warm_up_step)
    #print(optimizer.lr)
    epoch = 0
    valid_num = 1
    hist_valid_ppl = []

    print("begin training!")
    while (True):
        epoch += 1
        max_iter = int(math.ceil(len(train_data) / config.batch_size))
        with tqdm(total=max_iter, desc="train") as pbar:
            for src_sents, tar_sents, tar_words_num_to_predict in train_loader:
                optimizer.zero_grad()
                batch_size = len(src_sents)

                now_loss = -model(src_sents, tar_sents)
                now_loss = now_loss.sum()
                loss = now_loss / batch_size
                loss.backward()

                _ = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   config.clip_grad)
                #optimizer.updata_lr()
                optimizer.step_and_updata_lr()

                pbar.set_postfix({
                    "epwwoch":
                    epoch,
                    "avg_loss":
                    loss.item(),
                    "ppl":
                    math.exp(now_loss.item() / tar_words_num_to_predict),
                    "lr":
                    optimizer.lr
                })
                #pbar.set_postfix({"epoch": epoch, "avg_loss": loss.item(), "ppl": math.exp(now_loss.item()/tar_words_num_to_predict)})
                pbar.update(1)
        #print(optimizer.lr)
        if (epoch % config.valid_iter == 0):
            #if (epoch >= config.valid_iter//2):
            if (valid_num % 5 == 0):
                valid_num = 0
                optimizer.updata_lr()
            valid_num += 1
            print("now begin validation ...", file=sys.stderr)
            eav_ppl = evaluate_ppl(model, dev_data, dev_loader)
            print("validation ppl %.2f" % (eav_ppl), file=sys.stderr)
            flag = len(hist_valid_ppl) == 0 or eav_ppl < min(hist_valid_ppl)
            if (flag):
                print("current model is the best!, save to [%s]" %
                      (config.model_save_path),
                      file=sys.stderr)
                hist_valid_ppl.append(eav_ppl)
                model.save(
                    os.path.join(
                        config.model_save_path,
                        f"02.08_window35drop0.2_{epoch}_{eav_ppl}_checkpoint.pth"
                    ))
                torch.save(
                    optimizer.optimizer.state_dict(),
                    os.path.join(
                        config.model_save_path,
                        f"02.08_window35drop0.2_{epoch}_{eav_ppl}_optimizer.optim"
                    ))
        if (epoch == config.max_epoch):
            print("reach the maximum number of epochs!", file=sys.stderr)
            return
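# evaluate_ppl(model, dev_data, dev_loader) is called above but not defined in
# this snippet. A minimal sketch under the same conventions as the training loop
# (the model returns per-sentence log-likelihoods, the loader yields target word
# counts); the project's real function may differ.
import math
import torch


def evaluate_ppl_sketch(model, dev_loader):
    was_training = model.training
    model.eval()
    total_nll, total_words = 0.0, 0
    with torch.no_grad():
        for src_sents, tar_sents, tar_words_num_to_predict in dev_loader:
            total_nll += -model(src_sents, tar_sents).sum().item()
            total_words += tar_words_num_to_predict
    if was_training:
        model.train()
    # Corpus perplexity: exp(total NLL / total number of predicted target words).
    return math.exp(total_nll / total_words)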
Example #4
def main(hparams: HParams):
    '''
    setup training.
    '''
    if torch.cuda.is_available() and not hparams.gpus:
        warnings.warn(
            'WARNING: you have a CUDA device, so you should probably run with -gpus 0'
        )

    device = torch.device(hparams.gpus if torch.cuda.is_available() else 'cpu')

    # data setup
    print(f"Loading vocabulary...")
    text_preprocessor = TextPreprocessor.load(hparams.preprocessor_path)

    transform = transforms.Compose([
        transforms.Resize([hparams.img_size, hparams.img_size]),
        transforms.RandomCrop([hparams.crop_size, hparams.crop_size]),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # create dataloader
    print('Creating DataLoader...')
    normal_data_loader = get_image_caption_loader(
        hparams.img_dir,
        hparams.normal_caption_path,
        text_preprocessor,
        hparams.normal_batch_size,
        transform,
        shuffle=True,
        num_workers=hparams.num_workers,
    )

    style_data_loader = get_caption_loader(
        hparams.style_caption_path,
        text_preprocessor,
        batch_size=hparams.style_batch_size,
        shuffle=True,
        num_workers=hparams.num_workers,
    )

    if hparams.train_from:
        # loading checkpoint
        print('Loading checkpoint...')
        checkpoint = torch.load(hparams.train_from)
    else:
        normal_opt = Optim(
            hparams.optimizer,
            hparams.normal_lr,
            hparams.max_grad_norm,
            hparams.lr_decay,
            hparams.start_decay_at,
        )
        style_opt = Optim(
            hparams.optimizer,
            hparams.style_lr,
            hparams.max_grad_norm,
            hparams.lr_decay,
            hparams.start_decay_at,
        )

    print('Building model...')
    encoder = EncoderCNN(hparams.hidden_dim)
    decoder = FactoredLSTM(hparams.embed_dim,
                           text_preprocessor.vocab_size,
                           hparams.hidden_dim,
                           hparams.style_dim,
                           hparams.num_layers,
                           hparams.random_init,
                           hparams.dropout_ratio,
                           train=True,
                           device=device)

    encoder = encoder.to(device)
    decoder = decoder.to(device)

    # loss and optimizer
    criterion = nn.CrossEntropyLoss(ignore_index=text_preprocessor.PAD_ID)
    normal_params = list(encoder.parameters()) + list(
        decoder.default_parameters())
    style_params = list(decoder.style_parameters())
    normal_opt.set_parameters(normal_params)
    style_opt.set_parameters(style_params)

    if hparams.train_from:
        encoder.load_state_dict(checkpoint['encoder'])
        decoder.load_state_dict(checkpoint['decoder'])
        normal_opt.load_state_dict(checkpoint['normal_opt'])
        style_opt.load_state_dict(checkpoint['style_opt'])

    # training loop
    print('Start training...')
    for epoch in range(hparams.num_epoch):

        # result
        sum_normal_loss, sum_style_loss, sum_normal_ppl, sum_style_ppl = 0, 0, 0, 0

        # normal caption
        for i, (images, in_captions, out_captions,
                lengths) in enumerate(normal_data_loader):
            images = images.to(device)
            in_captions = in_captions.to(device)
            out_captions = out_captions.contiguous().view(-1).to(device)

            # Forward, backward and optimize
            features = encoder(images)
            outputs = decoder(in_captions, features, mode='default')
            loss = criterion(outputs.view(-1, outputs.size(-1)), out_captions)
            encoder.zero_grad()
            decoder.zero_grad()
            loss.backward()
            normal_opt.step()

            # print log
            sum_normal_loss += loss.item()
            sum_normal_ppl += np.exp(loss.item())
            if i % hparams.normal_log_step == 0:
                print(
                    f'Epoch [{epoch}/{hparams.num_epoch}], Normal Step: [{i}/{len(normal_data_loader)}] '
                    f'Normal Loss: {loss.item():.4f}, Perplexity: {np.exp(loss.item()):5.4f}'
                )

        # style caption
        for i, (in_captions, out_captions,
                lengths) in enumerate(style_data_loader):
            in_captions = in_captions.to(device)
            out_captions = out_captions.contiguous().view(-1).to(device)

            # Forward, backward and optimize
            outputs = decoder(in_captions, None, mode='style')
            loss = criterion(outputs.view(-1, outputs.size(-1)), out_captions)

            decoder.zero_grad()
            loss.backward()
            style_opt.step()

            sum_style_loss += loss.item()
            sum_style_ppl += np.exp(loss.item())
            # print log
            if i % hparams.style_log_step == 0:
                print(
                    f'Epoch [{epoch}/{hparams.num_epoch}], Style Step: [{i}/{len(style_data_loader)}] '
                    f'Style Loss: {loss.item():.4f}, Perplexity: {np.exp(loss.item()):5.4f}'
                )

        model_params = {
            'encoder': encoder.state_dict(),
            'decoder': decoder.state_dict(),
            'epoch': epoch,
            'normal_opt': normal_opt.optimizer.state_dict(),
            'style_opt': style_opt.optimizer.state_dict(),
        }

        avg_normal_loss = sum_normal_loss / len(normal_data_loader)
        avg_style_loss = sum_style_loss / len(style_data_loader)
        avg_normal_ppl = sum_normal_ppl / len(normal_data_loader)
        avg_style_ppl = sum_style_ppl / len(style_data_loader)
        print(f'Epoch [{epoch}/{hparams.num_epoch}] statistics')
        print(
            f'Normal Loss: {avg_normal_loss:.4f} Normal ppl: {avg_normal_ppl:5.4f} '
            f'Style Loss: {avg_style_loss:.4f} Style ppl: {avg_style_ppl:5.4f}'
        )

        torch.save(
            model_params,
            f'{hparams.model_path}/n-loss_{avg_normal_loss:.4f}_s-loss_{avg_style_loss:.4f}_'
            f'n-ppl_{avg_normal_ppl:5.4f}_s-ppl_{avg_style_ppl:5.4f}_epoch_{epoch}.pt'
        )
Example #5
if args.cuda:
    netG.cuda()
    netD.cuda()
netG_params_count = sum([p.nelement() for p in netG.parameters()])
netD_params_count = sum([p.nelement() for p in netD.parameters()])
print(netG)
print(netD)
print('netG has number of parameters: %d' % (netG_params_count))
print('netD has number of parameters: %d' % (netD_params_count))
one = torch.cuda.FloatTensor([1])
mone = one * -1

# ========= Setup loss function and optimizer  =========#
optimizerG = Optim(netG.parameters(),
                   args.optim,
                   lr=args.lr,
                   grad_clip=args.grad_clip,
                   weight_decay=args.weight_decay,
                   momentum=args.momentum)

optimizerD = Optim(netD.parameters(),
                   args.optim,
                   lr=args.lr,
                   grad_clip=args.grad_clip,
                   weight_decay=args.weight_decay,
                   momentum=args.momentum)

# sigma for mixture of RBF kernel in MMD
#sigma_list = [1.0]
#sigma_list = mmd_util.median_heuristic(Data.Y_subspace, beta=1.)
sigma_list = mmd_util.median_heuristic(Data.Y_subspace, beta=.5)
sigma_var = torch.FloatTensor(sigma_list).cuda()
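# mmd_util.median_heuristic is used above to pick RBF bandwidths but is not
# included in this snippet. A minimal sketch of the standard median heuristic
# (an assumption, not the project's implementation): take the median pairwise
# distance of the data, scale it by beta, and return a small list of sigmas
# around that value for a mixture of kernels.
import torch


def median_heuristic_sketch(X, beta=0.5):
    """Return candidate sigma values around the median pairwise distance."""
    X = torch.as_tensor(X, dtype=torch.float32).reshape(len(X), -1)
    d2 = torch.cdist(X, X, p=2).pow(2)       # pairwise squared distances
    med = d2[d2 > 0].median().sqrt().item()  # median distance, excluding self-pairs
    return [beta * med * (2.0 ** k) for k in range(-2, 3)]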
Example #6
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dataset_path = './data/txt'
data_path = './data'

train_data = RawDataLoader(dataset_path, 20)
print_every = 50

# Off-the-shelf BERT used for sentence segmentation
bert_model = BertTextNet()

# The combined full model
model = BigModel(device)
model = model.to(device)

# The model optimizer uses BERT's default settings
optimizer = Optim(model_size=768, factor=1, warmup=10000,\
    optimizer=torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.998), eps=1e-9))

batch_acm = 0
acc1_acm = 0.
acc2_acm = 0.
loss_acm = 0.
stoa_acc = 0.

while True:
    model.train()
    # Iterate over the dataset, reading 20 sentences at a time
    for sentences in train_data:
        batch_acm += 1

        tokens, segments, input_masks = [], [], []
        for sentence in sentences:
Example #7
# ##############################################################################
# Build model
# ##############################################################################
import model
from const import PAD
from optim import Optim

encode = model.Encode(use_cuda)
actor = model.Actor(args.vocab_size, args.dec_hsz, args.rnn_layers,
                    args.batch_size, args.max_len, args.dropout, use_cuda)

critic = model.Critic(args.vocab_size, args.dec_hsz, args.rnn_layers,
                      args.batch_size, args.max_len, args.dropout, use_cuda)

optim_pre_A = Optim(actor.parameters(), args.pre_lr, True)
optim_pre_C = Optim(critic.parameters(), args.pre_lr, True)

optim_A = Optim(actor.parameters(), args.lr, False, args.new_lr)
optim_C = Optim(critic.parameters(), args.lr, False, args.new_lr)

criterion_A = torch.nn.CrossEntropyLoss(ignore_index=PAD)
criterion_C = torch.nn.MSELoss()

if use_cuda:
    actor = actor.cuda()
    critic = critic.cuda()

# ##############################################################################
# Training
# ##############################################################################
Example #8
def train(index):
    torch.manual_seed(1)
    if (config.cuda):
        torch.cuda.manual_seed(1)
    device = torch.device(f"cuda:{index}" if config.cuda else "cpu")
    dist_rank = index
    torch.distributed.init_process_group(backend='nccl', init_method='tcp://localhost:23456', rank=dist_rank, world_size=1)
    is_master_node = (dist_rank == 0)
    
    args = dict()
    args['embed_size'] = config.embed_size
    args['d_model'] = config.d_model
    args['nhead'] = config.nhead
    args['num_encoder_layers'] = config.num_encoder_layers
    args['num_decoder_layers'] = config.num_decoder_layers
    args['dim_feedforward'] = config.dim_feedforward
    args['dropout'] = config.dropout
    args['smoothing_eps'] = config.smoothing_eps
    
    text = Text(config.src_corpus, config.tar_corpus)
    model = NMT(text, args, device)
    model = make_data_parallel(model, device)
    
    train_data = Data(config.train_path_src, config.train_path_tar)
    dev_data = Data(config.dev_path_src, config.dev_path_tar)
    train_sampler = DistributedSampler(train_data)
    dev_sampler = DistributedSampler(dev_data)
    train_loader = DataLoader(dataset=train_data, batch_size=int(config.train_batch_size/8), shuffle=False, num_workers=9, pin_memory=True, sampler=train_sampler, collate_fn=utils.get_batch)
    dev_loader = DataLoader(dataset=dev_data, batch_size=int(config.dev_batch_size/8), shuffle=False, num_workers=9, pin_memory=True, sampler=dev_sampler, collate_fn=utils.get_batch)

    model.train()
    optimizer = Optim(torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9), config.d_model, config.warm_up_step)

    epoch = 0
    history_valid_ppl = []
    print("begin training!", file=sys.stderr)
    while (True):
        epoch += 1
        train_loader.sampler.set_epoch(epoch)
        max_iter = int(math.ceil(len(train_data)/config.train_batch_size))
        with tqdm(total=max_iter, desc="train") as pbar:
            for batch_src, batch_tar, tar_word_num in train_loader:
                optimizer.zero_grad()
                now_batch_size = len(batch_src)
                batch_loss = -model(batch_src, batch_tar, smoothing=True)
                batch_loss = batch_loss.sum()
                loss = batch_loss / now_batch_size
                loss.backward()
                torch.distributed.barrier()
                optimizer.step_and_updata_lr()
                if (is_master_node):
                    pbar.set_postfix({"epoch": epoch, "avg_loss": '{%.2f}' % (loss.item()), "ppl": '{%.2f}' % (batch_loss.item()/tar_word_num)})
                    pbar.update(1)
        if (epoch % config.valid_iter == 0):
            print("now begin validation...", file=sys.stderr)
            torch.distributed.barrier()
            eval_ppl = evaluate_ppl(model, dev_data, dev_loader, config.dev_batch_size, is_master_node)
            print(eval_ppl)
            flag = len(history_valid_ppl) == 0 or eval_ppl < min(history_valid_ppl)
            if (flag):
                print(f"current model is the best! save to [{config.model_save_path}]", file=sys.stderr)
                history_valid_ppl.append(eval_ppl)
                model.save(os.path.join(config.model_save_path, f"02.19_{epoch}_{eval_ppl}_checkpoint.pth"))
                torch.save(optimizer.optimizer.state_dict(), os.path.join(config.model_save_path, f"02.19_{epoch}_{eval_ppl}_optimizer.optim"))
        if (epoch == config.max_epoch):
            print("reach the maximum number of epochs!", file=sys.stderr)
            return
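# make_data_parallel(model, device) is called above but not defined in this
# snippet. A minimal sketch (an assumption, not the project's helper): move the
# model to its device and, when a process group has been initialised, wrap it in
# DistributedDataParallel so gradients are averaged across ranks.
import torch
import torch.distributed as dist


def make_data_parallel_sketch(model, device):
    model = model.to(device)
    if dist.is_available() and dist.is_initialized():
        device_ids = [device.index] if device.type == "cuda" else None
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids)
    return model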
Example #9
def run(args, local_rank):
    """ Distributed Synchronous """
    torch.manual_seed(1234)
    vocab = Vocab(args.vocab, min_occur_cnt=args.min_occur_cnt, specials=[])
    if (args.world_size == 1 or dist.get_rank() == 0):
        print(vocab.size, flush=True)
    model = BIGLM(local_rank, vocab, args.embed_dim, args.ff_embed_dim,\
                  args.num_heads, args.dropout, args.layers, args.smoothing, args.approx)
    if args.start_from is not None:
        ckpt = torch.load(args.start_from, map_location='cpu')
        model.load_state_dict(ckpt['model'])
    model = model.cuda(local_rank)

    weight_decay_params = []
    no_weight_decay_params = []

    for name, param in model.named_parameters():
        if name.endswith('bias') or 'layer_norm' in name:
            no_weight_decay_params.append(param)
        else:
            weight_decay_params.append(param)
    grouped_params = [{
        'params': weight_decay_params,
        'weight_decay': args.weight_decay
    }, {
        'params': no_weight_decay_params,
        'weight_decay': 0.
    }]
    if args.world_size > 1:
        torch.manual_seed(1234 + dist.get_rank())
        random.seed(5678 + dist.get_rank())

    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        optimizer = FusedAdam(grouped_params,
                              lr=args.lr,
                              betas=(0.9, 0.999),
                              eps=1e-6,
                              bias_correction=False,
                              max_grad_norm=1.0)
        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)

    else:
        if args.weight_decay > 0:
            optimizer = AdamWeightDecayOptimizer(grouped_params,
                                                 lr=args.lr,
                                                 betas=(0.9, 0.999),
                                                 eps=1e-6)
        else:
            optimizer = Optim(
                model.embed_dim, args.lr, args.warmup_steps,
                torch.optim.Adam(model.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.998),
                                 eps=1e-9))

    if args.start_from is not None:
        optimizer.load_state_dict(ckpt['optimizer'])

    train_data = DataLoader(vocab, args.train_data, args.batch_size,
                            args.max_len, args.min_len)
    batch_acm = 0
    acc_acm, nll_acm, ppl_acm, ntokens_acm, nxs, npairs_acm, loss_acm = 0., 0., 0., 0., 0., 0., 0.
    while True:
        model.train()
        if train_data.epoch_id > 30:
            break
        for xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk in train_data:
            batch_acm += 1
            xs_tpl = xs_tpl.cuda(local_rank)
            xs_seg = xs_seg.cuda(local_rank)
            xs_pos = xs_pos.cuda(local_rank)
            ys_truth = ys_truth.cuda(local_rank)
            ys_inp = ys_inp.cuda(local_rank)
            ys_tpl = ys_tpl.cuda(local_rank)
            ys_seg = ys_seg.cuda(local_rank)
            ys_pos = ys_pos.cuda(local_rank)
            msk = msk.cuda(local_rank)

            model.zero_grad()
            res, loss, acc, nll, ppl, ntokens, npairs = model(
                xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg,
                ys_pos, msk)
            loss_acm += loss.item()
            acc_acm += acc
            nll_acm += nll
            ppl_acm += ppl
            ntokens_acm += ntokens
            npairs_acm += npairs
            nxs += npairs
            if args.fp16:
                optimizer.backward(loss)
            else:
                loss.backward()
            if args.world_size > 1:
                is_normal = average_gradients(model)
            else:
                is_normal = True
            if is_normal:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
            else:
                print("gradient: none, gpu: " + str(local_rank), flush=True)
                continue
            if (args.world_size == 1 or dist.get_rank() == 0
                ) and batch_acm % args.print_every == -1 % args.print_every:
                print ('batch_acm %d, loss %.3f, acc %.3f, nll %.3f, ppl %.3f, x_acm %d, lr %.6f'\
                        %(batch_acm, loss_acm/args.print_every, acc_acm/ntokens_acm, \
                        nll_acm/nxs, ppl_acm/nxs, npairs_acm, optimizer._rate), flush=True)
                acc_acm, nll_acm, ppl_acm, ntokens_acm, loss_acm, nxs = 0., 0., 0., 0., 0., 0.
            if (args.world_size == 1 or dist.get_rank() == 0
                ) and batch_acm % args.save_every == -1 % args.save_every:
                if not os.path.exists(args.save_dir):
                    os.mkdir(args.save_dir)
                torch.save(
                    {
                        'args': args,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }, '%s/epoch%d_batch_%d' %
                    (args.save_dir, train_data.epoch_id, batch_acm))
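# average_gradients(model) is called above when world_size > 1 and returns a
# flag, but its definition is not in this snippet. A sketch in the spirit of the
# PyTorch distributed tutorial, with an assumed finite-gradient check to explain
# the is_normal flag; the project's real helper may differ.
import torch
import torch.distributed as dist


def average_gradients_sketch(model):
    """All-reduce and average gradients; report whether they are all finite."""
    world_size = float(dist.get_world_size())
    is_normal = True
    for param in model.parameters():
        if param.grad is None:
            continue
        if not torch.isfinite(param.grad).all():
            is_normal = False
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= world_size
    return is_normal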
Example #10
def main():

    print("Loading data from '%s'" % opt.data)

    dataset = torch.load(opt.data)

    dict_checkpoint = opt.train_from if opt.train_from else opt.train_from_state_dict
    if dict_checkpoint:
        print('Loading dicts from checkpoint at %s' % dict_checkpoint)
        checkpoint = torch.load(dict_checkpoint)
        dataset['dicts'] = checkpoint['dicts']

    trainData = Dataset(dataset['train']['src'], dataset['train']['tgt'],
                        opt.batch_size, opt.gpus)
    validData = Dataset(dataset['valid']['src'],
                        dataset['valid']['tgt'],
                        opt.batch_size,
                        opt.gpus,
                        volatile=True)

    dicts = dataset['dicts']
    print(' * vocabulary size. source = %d; target = %d' %
          (len(dicts["word2index"]['src']), len(dicts["word2index"]['tgt'])))
    print(' * number of training sentences. %d' % len(dataset['train']['src']))
    print(' * maximum batch size. %d' % opt.batch_size)

    print('Building model...')

    encoder = Encoder(opt, len(dicts["word2index"]['src']))
    decoder = Decoder(opt, len(dicts["word2index"]['tgt']))

    generator = nn.Sequential(
        nn.Linear(opt.hidden_size * 2, len(dicts["word2index"]['tgt'])),
        nn.LogSoftmax())

    model = NMTModel(encoder, decoder)

    if opt.train_from:
        print('Loading model from checkpoint at %s' % opt.train_from)
        chk_model = checkpoint['model']
        generator_state_dict = chk_model.generator.state_dict()
        model_state_dict = {
            k: v
            for k, v in chk_model.state_dict().items() if 'generator' not in k
        }
        model.load_state_dict(model_state_dict)
        generator.load_state_dict(generator_state_dict)
        opt.start_epoch = checkpoint['epoch'] + 1

    if opt.train_from_state_dict:
        print('Loading model from checkpoint at %s' %
              opt.train_from_state_dict)
        model.load_state_dict(checkpoint['model'])
        generator.load_state_dict(checkpoint['generator'])
        opt.start_epoch = checkpoint['epoch'] + 1

    if len(opt.gpus) >= 1:
        model.cuda()
        generator.cuda()
    else:
        model.cpu()
        generator.cpu()

    if len(opt.gpus) > 1:
        model = nn.DataParallel(model, device_ids=opt.gpus, dim=1)
        generator = nn.DataParallel(generator, device_ids=opt.gpus, dim=0)

    model.generator = generator

    if not opt.train_from_state_dict and not opt.train_from:
        for p in model.parameters():
            p.data.uniform_(-opt.param_init, opt.param_init)

        encoder.load_pretrained_vectors(opt)
        decoder.load_pretrained_vectors(opt)

        optim = Optim(opt.optim,
                      opt.learning_rate,
                      opt.max_grad_norm,
                      lr_decay=opt.learning_rate_decay,
                      start_decay_at=opt.start_decay_at)
    else:
        print('Loading optimizer from checkpoint:')
        optim = checkpoint['optim']
        print(optim)

    optim.set_parameters(model.parameters())

    if opt.train_from or opt.train_from_state_dict:
        optim.optimizer.load_state_dict(
            checkpoint['optim'].optimizer.state_dict())

    nParams = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % nParams)

    criterion = NMTCriterion(len(dicts["word2index"]['tgt']))

    trainModel(model, trainData, validData, dataset, optim, criterion)
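# The Optim(opt.optim, opt.learning_rate, opt.max_grad_norm, lr_decay=...,
# start_decay_at=...) wrapper used above is not included in this snippet. A
# minimal OpenNMT-style sketch, written as an assumption about its interface:
# set_parameters() binds the parameters and builds the inner torch optimizer
# (exposed as .optimizer for checkpointing), and step() clips before updating.
import torch


class OptimSketch:
    def __init__(self, method, lr, max_grad_norm, lr_decay=1.0, start_decay_at=None):
        self.method, self.lr = method, lr
        self.max_grad_norm = max_grad_norm
        self.lr_decay, self.start_decay_at = lr_decay, start_decay_at

    def set_parameters(self, params):
        self.params = [p for p in params if p.requires_grad]
        builders = {'sgd': torch.optim.SGD, 'adam': torch.optim.Adam,
                    'adagrad': torch.optim.Adagrad, 'adadelta': torch.optim.Adadelta}
        self.optimizer = builders[self.method](self.params, lr=self.lr)

    def step(self):
        if self.max_grad_norm:
            torch.nn.utils.clip_grad_norm_(self.params, self.max_grad_norm)
        self.optimizer.step()

    def update_learning_rate(self, epoch):
        # Decay once past start_decay_at (a real version would also watch validation ppl).
        if self.start_decay_at is not None and epoch >= self.start_decay_at:
            self.lr *= self.lr_decay
            for group in self.optimizer.param_groups:
                group['lr'] = self.lr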
Example #11
                              evaluation=True)

# ##############################################################################
# Build model
# ##############################################################################
import model
from const import PAD
from optim import Optim

actor = model.Actor(args.vocab_size, args.dec_hsz, args.rnn_layers,
                    args.batch_size, args.max_len, args.dropout, use_cuda)

critic = model.Critic(args.vocab_size, args.dec_hsz, args.rnn_layers,
                      args.batch_size, args.max_len, args.dropout, use_cuda)

optim_pre_A = Optim(actor.get_trainable_parameters(), args.pre_lr, True,
                    args.grad_clip)
optim_pre_C = Optim(critic.parameters(), args.pre_lr, True, args.grad_clip)

optim_A = Optim(actor.get_trainable_parameters(), args.lr, False, args.new_lr,
                args.grad_clip)
optim_C = Optim(critic.parameters(), args.lr, False, args.new_lr,
                args.grad_clip)

criterion_A = torch.nn.CrossEntropyLoss(ignore_index=PAD)
criterion_C = torch.nn.MSELoss()
criterion_AC = model.RewardCriterion()

if use_cuda:
    actor = actor.cuda()
    critic = critic.cuda()
Example #12
def run(args, local_rank):
    """ Distributed Synchronous """
    torch.manual_seed(1234)
    vocab = Vocab(args.vocab, min_occur_cnt=args.min_occur_cnt, specials=[])
    if (args.world_size == 1 or dist.get_rank() == 0):
        print("vocab.size = " + str(vocab.size), flush=True)
    model = BIGLM(local_rank, vocab, args.embed_dim, args.ff_embed_dim,\
                  args.num_heads, args.dropout, args.layers, args.smoothing)
    if args.start_from is not None:
        ckpt = torch.load(args.start_from, map_location='cpu')
        model.load_state_dict(ckpt['model'])
    model = model.cuda(local_rank)

    optimizer = Optim(
        model.embed_dim, args.lr, args.warmup_steps,
        torch.optim.Adam(model.parameters(),
                         lr=0,
                         betas=(0.9, 0.998),
                         eps=1e-9))

    if args.start_from is not None:
        optimizer.load_state_dict(ckpt['optimizer'])

    train_data = DataLoader(vocab, args.train_data, args.batch_size,
                            args.max_len, args.min_len)
    batch_acm = 0
    acc_acm, nll_acm, ppl_acm, ntokens_acm, nxs, npairs_acm, loss_acm = 0., 0., 0., 0., 0., 0., 0.
    while True:
        model.train()
        if train_data.epoch_id > 30:
            break
        for xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk in train_data:
            batch_acm += 1
            xs_tpl = xs_tpl.cuda(local_rank)
            xs_seg = xs_seg.cuda(local_rank)
            xs_pos = xs_pos.cuda(local_rank)
            ys_truth = ys_truth.cuda(local_rank)
            ys_inp = ys_inp.cuda(local_rank)
            ys_tpl = ys_tpl.cuda(local_rank)
            ys_seg = ys_seg.cuda(local_rank)
            ys_pos = ys_pos.cuda(local_rank)
            msk = msk.cuda(local_rank)

            model.zero_grad()
            res, loss, acc, nll, ppl, ntokens, npairs = model(
                xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg,
                ys_pos, msk)
            loss_acm += loss.item()
            acc_acm += acc
            nll_acm += nll
            ppl_acm += ppl
            ntokens_acm += ntokens
            npairs_acm += npairs
            nxs += npairs

            loss.backward()
            if args.world_size > 1:
                is_normal = average_gradients(model)
            else:
                is_normal = True
            if is_normal:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
            else:
                print("gradient: none, gpu: " + str(local_rank), flush=True)
                continue
            if (args.world_size == 1 or dist.get_rank() == 0
                ) and batch_acm % args.print_every == -1 % args.print_every:
                print ('batch_acm %d, loss %.3f, acc %.3f, nll %.3f, ppl %.3f, x_acm %d, lr %.6f'\
                        %(batch_acm, loss_acm/args.print_every, acc_acm/ntokens_acm, \
                        nll_acm/nxs, ppl_acm/nxs, npairs_acm, optimizer._rate), flush=True)
                acc_acm, nll_acm, ppl_acm, ntokens_acm, loss_acm, nxs = 0., 0., 0., 0., 0., 0.
            if (args.world_size == 1 or dist.get_rank() == 0
                ) and batch_acm % args.save_every == -1 % args.save_every:
                if not os.path.exists(args.save_dir):
                    os.mkdir(args.save_dir)

                model.eval()
                eval_epoch(
                    args, model, vocab, local_rank, "epoch-" +
                    str(train_data.epoch_id) + "-acm-" + str(batch_acm))
                model.train()

                torch.save(
                    {
                        'args': args,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }, '%s/epoch%d_batch_%d' %
                    (args.save_dir, train_data.epoch_id, batch_acm))
Example #13
def train():
    torch.manual_seed(1)
    if (config.cuda):
        torch.cuda.manual_seed(1)
    args = dict()
    args['embed_size'] = config.embed_size
    args['d_model'] = config.d_model
    args['nhead'] = config.nhead
    args['num_encoder_layers'] = config.num_encoder_layers
    args['num_decoder_layers'] = config.num_decoder_layers
    args['dim_feedforward'] = config.dim_feedforward
    args['dropout'] = config.dropout
    args['smoothing_eps'] = config.smoothing_eps
    text = Text(config.src_corpus, config.tar_corpus)
    train_data = Data(config.train_path_src, config.train_path_tar)
    dev_data = Data(config.dev_path_src, config.dev_path_tar)
    train_loader = DataLoader(dataset=train_data,
                              batch_size=config.train_batch_size,
                              shuffle=True,
                              collate_fn=utils.get_batch)
    dev_loader = DataLoader(dataset=dev_data,
                            batch_size=config.dev_batch_size,
                            shuffle=True,
                            collate_fn=utils.get_batch)
    #train_data_src, train_data_tar = utils.read_corpus(config.train_path)
    #dev_data_src, dev_data_tar = utils.read_corpus(config.dev_path)
    device = torch.device("cuda:0" if config.cuda else "cpu")
    model = NMT(text, args, device)
    #model = nn.DataParallel(model, device_ids=[0, 1])
    model = model.to(device)
    #model = model.module
    #model_path = "/home/wangshuhe/shuhelearn/ShuHeLearning/NMT_transformer/result/02.01_1_344.6820465077113_checkpoint.pth"
    #model = NMT.load(model_path)
    #model = model.to(device)
    model.train()
    optimizer = Optim(
        torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9),
        config.d_model, config.warm_up_step)
    #optimizer = Optim(torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9), config.warm_up_step, config.init_lr, config.lr)
    #optimizer = Optim(torch.optim.Adam(model.parameters()))

    epoch = 0
    history_valid_ppl = []
    print("begin training!", file=sys.stderr)
    while (True):
        epoch += 1
        max_iter = int(math.ceil(len(train_data) / config.train_batch_size))
        with tqdm(total=max_iter, desc="train") as pbar:
            #for batch_src, batch_tar, tar_word_num in utils.batch_iter(train_data_src, train_data_tar, config.train_batch_size):
            for batch_src, batch_tar, tar_word_num in train_loader:
                optimizer.zero_grad()
                now_batch_size = len(batch_src)
                batch_loss = -model(batch_src, batch_tar, smoothing=True)
                batch_loss = batch_loss.sum()
                loss = batch_loss / now_batch_size
                loss.backward()
                #optimizer.step()
                #optimizer.updata_lr()
                optimizer.step_and_updata_lr()
                pbar.set_postfix({
                    "epoch":
                    epoch,
                    "avg_loss":
                    '{%.2f}' % (loss.item()),
                    "ppl":
                    '{%.2f}' % (math.exp(batch_loss.item() / tar_word_num))
                })
                pbar.update(1)
        if (epoch % config.valid_iter == 0):
            print("now begin validation...", file=sys.stderr)
            eval_ppl = evaluate_ppl(model, dev_data, dev_loader,
                                    config.dev_batch_size)
            print(eval_ppl)
            flag = len(
                history_valid_ppl) == 0 or eval_ppl < min(history_valid_ppl)
            if (flag):
                print(
                    f"current model is the best! save to [{config.model_save_path}]",
                    file=sys.stderr)
                history_valid_ppl.append(eval_ppl)
                model.save(
                    os.path.join(config.model_save_path,
                                 f"02.10_{epoch}_{eval_ppl}_checkpoint.pth"))
                torch.save(
                    optimizer.optimizer.state_dict(),
                    os.path.join(config.model_save_path,
                                 f"02.10_{epoch}_{eval_ppl}_optimizer.optim"))
        if (epoch == config.max_epoch):
            print("reach the maximum number of epochs!", file=sys.stderr)
            return
Example #14
# ========= Setup loss function and optimizer  =========#
if args.loss == 'L1':
    criterion = nn.L1Loss(size_average=True)
elif args.loss == 'L2':
    criterion = nn.MSELoss(size_average=True)
elif args.loss == 'Huber':
    criterion = nn.SmoothL1Loss(size_average=True)
else:
    raise NotImplementedError(
        'Loss function %s is not supported! Consider L1|L2|Huber' %
        (args.loss))
if args.cuda:
    criterion = criterion.cuda()
optimizer = Optim(model.parameters(),
                  args.optim,
                  lr=args.lr,
                  grad_clip=args.grad_clip,
                  weight_decay=args.weight_decay,
                  momentum=args.momentum)

# sigma for mixture of RBF kernel in MMD
#sigma_list = [1.0]
#sigma_list = mmd_util.median_heuristic(Data.Y_subspace, beta=1.)
sigma_list = mmd_util.median_heuristic(Data.Y_subspace, beta=.5)
sigma_var = torch.FloatTensor(sigma_list).cuda()
print('sigma_list:', sigma_var)

# ========= Main loop for pretraining RNN with reconstruction loss  =========#
Y_val = Data.val_set['Y'].numpy()
L_val = Data.val_set['L'].numpy()
Y_tst = Data.tst_set['Y'].numpy()
L_tst = Data.tst_set['L'].numpy()
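# sigma_list above parameterises a mixture of RBF kernels for MMD; the estimator
# itself lives in mmd_util and is not shown here. A minimal sketch of the biased
# MMD^2 estimate with such a mixture (an illustration, not the project's code):
import torch


def mix_rbf_mmd2_sketch(X, Y, sigma_list):
    """MMD^2(X, Y) = E[k(x,x')] + E[k(y,y')] - 2 E[k(x,y)], k = sum of RBF kernels."""
    Z = torch.cat([X, Y], dim=0)
    d2 = torch.cdist(Z, Z, p=2).pow(2)
    K = sum(torch.exp(-d2 / (2.0 * sigma ** 2)) for sigma in sigma_list)
    m = X.size(0)
    K_xx, K_yy, K_xy = K[:m, :m], K[m:, m:], K[:m, m:]
    return K_xx.mean() + K_yy.mean() - 2.0 * K_xy.mean()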
Example #15
def run(args, local_rank):
    """ Distributed Synchronous """
    torch.manual_seed(1234)
    vocab = Vocab(args.vocab, min_occur_cnt=args.min_occur_cnt, specials=[])
    if (args.world_size == 1 or dist.get_rank() == 0):
        print("vocab.size = " + str(vocab.size), flush=True)
    model = BIGLM(local_rank, vocab, args.embed_dim, args.ff_embed_dim,
                  args.num_heads, args.dropout, args.layers, args.smoothing)
    if args.start_from is not None:
        ckpt = torch.load(args.start_from, map_location='cpu')
        model.load_state_dict(ckpt['model'])
    model = model.cuda(local_rank)

    optimizer = Optim(
        model.embed_dim, args.lr, args.warmup_steps,
        torch.optim.Adam(model.parameters(),
                         lr=0,
                         betas=(0.9, 0.998),
                         eps=1e-9))

    if args.start_from is not None:
        optimizer.load_state_dict(ckpt['optimizer'])

    train_data = DataLoader(vocab, args.train_data, args.batch_size,
                            args.max_len, args.min_len)
    batch_acm = 0
    acc_acm, nll_acm, ppl_acm, ntokens_acm, nxs, npairs_acm, loss_acm = 0., 0., 0., 0., 0., 0., 0.
    while True:
        model.train()
        if train_data.epoch_id > args.max_epoch:
            break
        for xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk in train_data:
            batch_acm += 1
            xs_tpl = xs_tpl.cuda(local_rank)
            xs_seg = xs_seg.cuda(local_rank)
            xs_pos = xs_pos.cuda(local_rank)
            ys_truth = ys_truth.cuda(local_rank)
            ys_inp = ys_inp.cuda(local_rank)
            ys_tpl = ys_tpl.cuda(local_rank)
            ys_seg = ys_seg.cuda(local_rank)
            ys_pos = ys_pos.cuda(local_rank)
            msk = msk.cuda(local_rank)

            model.zero_grad()
            res, loss, acc, nll, ppl, ntokens, npairs = model(
                xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg,
                ys_pos, msk)

            # Reference: http://www.myzaker.com/article/5f3747a28e9f096c723a65e0/
            # Besides the common text-generation metrics PPL and Distinct,
            # the paper also designs metrics for format accuracy, rhyme accuracy
            # and sentence integrity.
            # Format accuracy: precision p, recall r and F1 -> Macro-F1 and Micro-F1.
            # Integrity is scored with a somewhat odd log-based value.
            # Traditional BLEU and ROUGE are not used at all in SongNet, since
            # creative generation calls for diversity.
            loss_acm += loss.item()  # loss
            acc_acm += acc  # accuracy
            nll_acm += nll  # negative log-likelihood
            ppl_acm += ppl  # sum of -log probs, i.e. the sentence probability; the smaller, the higher the perplexity
            # perplexity compares how well models predict the samples; lower is better
            ntokens_acm += ntokens  # token count
            npairs_acm += npairs  # sentence (pair) count?
            nxs += npairs

            # why does this feel so hard, GPT-2

            loss.backward()
            if args.world_size > 1:
                is_normal = average_gradients(model)
            else:
                is_normal = True
            if is_normal:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
            else:
                print("gradient: none, gpu: " + str(local_rank), flush=True)
                continue
            if (args.world_size == 1 or dist.get_rank() == 0
                ) and batch_acm % args.print_every == -1 % args.print_every:
                today = datetime.datetime.now()
                print(today)
                print(
                    'batch_acm %d, loss %.3f, acc %.3f, nll %.3f, ppl %.3f, x_acm %d, lr %.6f'
                    % (batch_acm, loss_acm / args.print_every,
                       acc_acm / ntokens_acm, nll_acm / nxs, ppl_acm / nxs,
                       npairs_acm, optimizer._rate),
                    flush=True)
                acc_acm, nll_acm, ppl_acm, ntokens_acm, loss_acm, nxs = 0., 0., 0., 0., 0., 0.
            if (args.world_size == 1 or dist.get_rank() == 0
                ) and batch_acm % args.save_every == -1 % args.save_every:
                if not os.path.exists(args.save_dir):
                    os.mkdir(args.save_dir)

                model.eval()
                eval_epoch(
                    args, model, vocab, local_rank, "epoch-" +
                    str(train_data.epoch_id) + "-acm-" + str(batch_acm))
                model.train()

                torch.save(
                    {
                        'args': args,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }, '%s/epoch%d_batch_%d' %
                    (args.save_dir, train_data.epoch_id, batch_acm))
def main():
    args = make_parser().parse_args()
    print("[Model hyperparams]: {}".format(str(args)))

    cuda = torch.cuda.is_available() and args.cuda
    device = torch.device("cpu") if not cuda else torch.device("cuda:" +
                                                               str(args.gpu))
    seed_everything(seed=1337, cuda=cuda)
    vectors = None  #don't use pretrained vectors
    # vectors = load_pretrained_vectors(args.emsize)

    # Load dataset iterators
    iters, TEXT, LABEL, PROMPTS = dataset_map[args.data](
        args.batch_size,
        device=device,
        vectors=vectors,
        base_path=args.base_path)

    # Some datasets just have the train & test sets, so we just pretend test is valid
    if len(iters) >= 4:
        train_iter = iters[0]
        val_iter = iters[1]
        test_iter = iters[2]
        outdomain_test_iter = list(iters[3:])
    elif len(iters) == 3:
        train_iter, val_iter, test_iter = iters
    else:
        train_iter, test_iter = iters
        val_iter = test_iter

    print("[Corpus]: train: {}, test: {}, vocab: {}, labels: {}".format(
        len(train_iter.dataset), len(test_iter.dataset), len(TEXT.vocab),
        len(LABEL.vocab)))

    if args.model == "CNN":
        args.embed_num = len(TEXT.vocab)
        args.nlabels = len(LABEL.vocab)
        args.nprompts = len(PROMPTS.vocab)
        args.kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
        args.embed_dim = args.emsize
        classifier_model = CNN_Text_GANLike(args)
        topic_decoder = [
            nn.Sequential(
                nn.Linear(
                    len(args.kernel_sizes) * args.kernel_num, args.nprompts))
        ]

    else:
        ntokens, nlabels, nprompts = len(TEXT.vocab), len(LABEL.vocab), len(
            PROMPTS.vocab)
        args.nlabels = nlabels  # hack to not clutter function arguments
        args.nprompts = nprompts
        embedding = nn.Embedding(ntokens, args.emsize, padding_idx=1)
        encoder = Encoder(args.emsize,
                          args.hidden,
                          nlayers=args.nlayers,
                          dropout=args.drop,
                          bidirectional=args.bi,
                          rnn_type=args.rnn_model)

        attention_dim = args.hidden if not args.bi else 2 * args.hidden
        attention = BahdanauAttention(attention_dim, attention_dim)

        if args.bottleneck_dim == 0:
            classifier_model = Classifier_GANLike(embedding, encoder,
                                                  attention, attention_dim,
                                                  nlabels)
            topic_decoder = [
                nn.Sequential(nn.Dropout(args.topic_drop),
                              nn.Linear(attention_dim, args.nprompts))
            ]
        else:
            classifier_model = Classifier_GANLike_bottleneck(
                embedding,
                encoder,
                attention,
                attention_dim,
                nlabels,
                bottleneck_dim=args.bottleneck_dim)
            topic_decoder = [
                nn.Sequential(nn.Dropout(args.topic_drop),
                              nn.Linear(args.bottleneck_dim, args.nprompts))
            ]

    classifier_model.to(device)
    topic_decoder[0].to(device)

    classify_criterion = nn.CrossEntropyLoss()
    topic_criterion = nn.CrossEntropyLoss()

    classify_optim = Optim(args.optim, args.lr, args.clip)
    topic_optim = Optim(args.optim, args.lr, args.clip)

    for p in classifier_model.parameters():
        if not p.requires_grad:
            print("OMG", p)
            p.requires_grad = True
        p.data.uniform_(-args.param_init, args.param_init)

    for p in topic_decoder[0].parameters():
        if not p.requires_grad:
            print("OMG", p)
            p.requires_grad = True
        p.data.uniform_(-args.param_init, args.param_init)

    classify_optim.set_parameters(classifier_model.parameters())
    topic_optim.set_parameters(topic_decoder[0].parameters())

    if args.load:
        if args.latest:
            best_model = torch.load(args.save_dir + "/" + args.model_name +
                                    "_latestmodel")
        else:
            best_model = torch.load(args.save_dir + "/" + args.model_name +
                                    "_bestmodel")
    else:
        try:
            best_valid_loss = None
            best_model = None

            #pretraining the classifier
            for epoch in range(1, args.pretrain_epochs + 1):
                pretrain_classifier(classifier_model, train_iter,
                                    classify_optim, classify_criterion, args,
                                    epoch)
                loss = evaluate(classifier_model, val_iter, classify_criterion,
                                args)
                # oodLoss = evaluate(classifier_model, outdomain_test_iter[0], classify_criterion, args, datatype="oodtest")

                if not best_valid_loss or loss < best_valid_loss:
                    best_valid_loss = loss
                    print("Updating best pretrained_model")
                    best_model = copy.deepcopy(classifier_model)
                    torch.save(
                        best_model, args.save_dir + "/" + args.model_name +
                        "_pretrained_bestmodel")
                torch.save(
                    classifier_model, args.save_dir + "/" + args.model_name +
                    "_pretrained_latestmodel")

            # classifier_model = best_model
            #alternating training like GANs
            for epoch in range(1, args.epochs + 1):
                for t_step in range(1, args.t_steps + 1):
                    train_topic_predictor(classifier_model, topic_decoder[-1],
                                          train_iter, topic_optim,
                                          topic_criterion, args, epoch,
                                          args.t_steps)

                if args.reset_classifier:
                    for p in classifier_model.parameters():
                        if not p.requires_grad:
                            print("OMG", p)
                            p.requires_grad = True
                        p.data.uniform_(-args.param_init, args.param_init)

                for c_step in range(1, args.c_steps + 1):
                    train_classifier(classifier_model, topic_decoder,
                                     train_iter, classify_optim,
                                     classify_criterion, topic_criterion, args,
                                     epoch, args.c_steps)
                    loss = evaluate(classifier_model, val_iter,
                                    classify_criterion, args)
                    # oodLoss = evaluate(classifier_model, outdomain_test_iter[0], classify_criterion, args, datatype="oodtest")

                #creating a new instance of a decoder
                if args.model == "CNN":
                    topic_decoder += [
                        nn.Sequential(
                            nn.Linear(
                                len(args.kernel_sizes) * args.kernel_num,
                                args.nprompts))
                    ]
                else:
                    attention_dim = args.hidden if not args.bi else 2 * args.hidden
                    if args.bottleneck_dim == 0:
                        topic_decoder.append(
                            nn.Sequential(
                                nn.Dropout(args.topic_drop),
                                nn.Linear(attention_dim, args.nprompts)))
                    else:
                        topic_decoder.append(
                            nn.Sequential(
                                nn.Dropout(args.topic_drop),
                                nn.Linear(args.bottleneck_dim, args.nprompts)))

                #attaching a new optimizer to the new topic decode
                topic_decoder[-1].to(device)
                topic_optim = Optim(args.optim, args.lr, args.clip)
                for p in topic_decoder[-1].parameters():
                    if not p.requires_grad:
                        print("OMG", p)
                        p.requires_grad = True
                    p.data.uniform_(-args.param_init, args.param_init)
                topic_optim.set_parameters(topic_decoder[-1].parameters())

                if not best_valid_loss or loss < best_valid_loss:
                    best_valid_loss = loss
                    print("Updating best model")
                    best_model = copy.deepcopy(classifier_model)
                    torch.save(
                        best_model,
                        args.save_dir + "/" + args.model_name + "_bestmodel")
                torch.save(
                    classifier_model,
                    args.save_dir + "/" + args.model_name + "_latestmodel")

        except KeyboardInterrupt:
            print("[Ctrl+C] Training stopped!")

    # if not args.load:
    trainloss = evaluate(best_model,
                         train_iter,
                         classify_criterion,
                         args,
                         datatype='train',
                         writetopics=args.save_output_topics,
                         itos=TEXT.vocab.itos,
                         litos=LABEL.vocab.itos)
    valloss = evaluate(best_model,
                       val_iter,
                       classify_criterion,
                       args,
                       datatype='valid',
                       writetopics=args.save_output_topics,
                       itos=TEXT.vocab.itos,
                       litos=LABEL.vocab.itos)

    loss = evaluate(best_model,
                    test_iter,
                    classify_criterion,
                    args,
                    datatype='test',
                    writetopics=args.save_output_topics,
                    itos=TEXT.vocab.itos,
                    litos=LABEL.vocab.itos)
    if args.data == "AMAZON":
        oodnames = args.oodname.split(",")
        for oodname, oodtest_iter in zip(oodnames, outdomain_test_iter):
            oodLoss = evaluate(best_model,
                               oodtest_iter,
                               classify_criterion,
                               args,
                               datatype=oodname + "_bestmodel",
                               writetopics=args.save_output_topics)
            oodLoss = evaluate(classifier_model,
                               oodtest_iter,
                               classify_criterion,
                               args,
                               datatype=oodname + "_latest",
                               writetopics=args.save_output_topics)
    else:
        oodLoss = evaluate(best_model,
                           outdomain_test_iter[0],
                           classify_criterion,
                           args,
                           datatype="oodtest_bestmodel",
                           writetopics=args.save_output_topics,
                           itos=TEXT.vocab.itos,
                           litos=LABEL.vocab.itos)
        oodLoss = evaluate(classifier_model,
                           outdomain_test_iter[0],
                           classify_criterion,
                           args,
                           datatype="oodtest_latest",
                           writetopics=args.save_output_topics)
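

# The scripts above and below build their optimizers through an `Optim(args.optim,
# args.lr, args.clip)` wrapper and register parameters via `set_parameters(...)`.
# That wrapper is defined elsewhere and is not shown here; `_OptimSketch` below is
# only a minimal sketch of how such a wrapper is commonly written (optimizer picked
# by name, gradient clipping folded into step()). It is an assumption for
# illustration, not the project's actual Optim class.
import torch


class _OptimSketch(object):
  def __init__(self, method, lr, max_grad_norm):
    self.method = method
    self.lr = lr
    self.max_grad_norm = max_grad_norm

  def set_parameters(self, params):
    # keep only trainable parameters, then build the underlying torch optimizer
    self.params = [p for p in params if p.requires_grad]
    if self.method == 'sgd':
      self.optimizer = torch.optim.SGD(self.params, lr=self.lr)
    elif self.method == 'adam':
      self.optimizer = torch.optim.Adam(self.params, lr=self.lr)
    else:
      raise ValueError("unknown optimizer: %s" % self.method)

  def zero_grad(self):
    self.optimizer.zero_grad()

  def step(self):
    # clip gradients before updating, mirroring the `clip` argument used above
    if self.max_grad_norm:
      torch.nn.utils.clip_grad_norm_(self.params, self.max_grad_norm)
    self.optimizer.step()

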
def main():
  args = make_parser().parse_args()
  print("[Model hyperparams]: {}".format(str(args)))

  cuda = torch.cuda.is_available() and args.cuda
  device = torch.device("cpu") if not cuda else torch.device("cuda:"+str(args.gpu))
  seed_everything(seed=1337, cuda=cuda)
  vectors = None #don't use pretrained vectors
  # vectors = load_pretrained_vectors(args.emsize)

  # Load dataset iterators
  if args.data in ["RT_GENDER"]:
    if args.finetune:
      iters, TEXT, LABEL, INDEX = make_rt_gender(args.batch_size, base_path=args.base_path, train_file=args.train_file, valid_file=args.valid_file, test_file=args.test_file, device=device, vectors=vectors, topics=False)
    else:
      iters, TEXT, LABEL, TOPICS, INDEX = make_rt_gender(args.batch_size, base_path=args.base_path, train_file=args.train_file, valid_file=args.valid_file, test_file=args.test_file, device=device, vectors=vectors, topics=True)
    train_iter, val_iter, test_iter = iters
  else:
    assert False, "Unsupported dataset: {}".format(args.data)

  if not args.finetune:
    # read the number of topics from the first batch's topic vectors
    for batch in train_iter:
      args.num_topics = batch.topics.shape[1]
      break

  print("[Corpus]: train: {}, test: {}, vocab: {}, labels: {}".format(
            len(train_iter.dataset), len(test_iter.dataset), len(TEXT.vocab), len(LABEL.vocab)))

  if args.model == "CNN":
    args.embed_num = len(TEXT.vocab)
    args.nlabels = len(LABEL.vocab)
    args.kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
    args.embed_dim = args.emsize
    classifier_model = CNN_Text_GANLike(args)
    # keep the decoder heads in a list so the CNN branch matches the RNN branch below
    topic_decoder = [nn.Sequential(nn.Linear(len(args.kernel_sizes)*args.kernel_num, args.num_topics), nn.LogSoftmax(dim=-1))]

  else:
    ntokens, nlabels = len(TEXT.vocab), len(LABEL.vocab)
    args.nlabels = nlabels # hack to not clutter function arguments

    embedding = nn.Embedding(ntokens, args.emsize, padding_idx=1)
    encoder = Encoder(args.emsize, args.hidden, nlayers=args.nlayers,
                      dropout=args.drop, bidirectional=args.bi, rnn_type=args.rnn_model)

    attention_dim = args.hidden if not args.bi else 2*args.hidden
    attention = BahdanauAttention(attention_dim, attention_dim)

    if args.bottleneck_dim == 0:
      classifier_model = Classifier_GANLike(embedding, encoder, attention, attention_dim, nlabels)
      topic_decoder = [nn.Sequential(nn.Dropout(args.topic_drop), nn.Linear(attention_dim, args.num_topics), nn.LogSoftmax(dim=-1))]
    else:
      classifier_model = Classifier_GANLike_bottleneck(embedding, encoder, attention, attention_dim, nlabels, bottleneck_dim=args.bottleneck_dim)
      topic_decoder = [nn.Sequential(nn.Dropout(args.topic_drop), nn.Linear(args.bottleneck_dim, args.num_topics), nn.LogSoftmax(dim=-1))]


  classifier_model.to(device)
  topic_decoder[0].to(device)

  classify_criterion = nn.CrossEntropyLoss()
  # KLDivLoss expects log-probabilities as input (hence the LogSoftmax heads);
  # reduction='sum' replaces the deprecated size_average=False
  topic_criterion = nn.KLDivLoss(reduction='sum')

  classify_optim = Optim(args.optim, args.lr, args.clip)
  topic_optim = Optim(args.optim, args.lr, args.clip)

  for p in classifier_model.parameters():
    if not p.requires_grad:
      print ("OMG", p)
      p.requires_grad = True
    p.data.uniform_(-args.param_init, args.param_init)

  for p in topic_decoder[0].parameters():
    if not p.requires_grad:
      print ("OMG", p)
      p.requires_grad = True
    p.data.uniform_(-args.param_init, args.param_init)

  classify_optim.set_parameters(classifier_model.parameters())
  topic_optim.set_parameters(topic_decoder[0].parameters())

  if args.load:
    if args.latest:
      best_model = torch.load(args.save_dir+"/"+args.model_name+"_latestmodel")
    else:
      best_model = torch.load(args.save_dir+"/"+args.model_name+"_bestmodel")
  else:
    try:
      best_valid_loss = None
      best_model = None

      #pretraining the classifier
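      # Pretraining loop: after each epoch the classifier is scored on the validation
      # set and the best/latest pretrained checkpoints are written to args.save_dir.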
      for epoch in range(1, args.pretrain_epochs+1):
        pretrain_classifier(classifier_model, train_iter, classify_optim, classify_criterion, args, epoch)
        loss = evaluate(classifier_model, topic_decoder, val_iter, classify_criterion, topic_criterion, args)
        #oodLoss = evaluate(classifier_model, outdomain_test_iter[0], classify_criterion, args, datatype="oodtest")

        if not best_valid_loss or loss < best_valid_loss:
          best_valid_loss = loss
          print ("Updating best pretrained_model")
          best_model = copy.deepcopy(classifier_model)
          torch.save(best_model, args.save_dir+"/"+args.model_name+"_pretrained_bestmodel")
        torch.save(classifier_model, args.save_dir+"/"+args.model_name+"_pretrained_latestmodel")

      print("Done pretraining")
      print()
      best_valid_loss = None
      best_model = None
      #alternating training like GANs
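      # Each epoch below: (1) the newest topic decoder head is trained for t_steps to
      # predict the topic distribution from the classifier's representation, (2) the
      # classifier is optionally re-initialized, (3) the classifier is trained for
      # c_steps with the full list of topic decoder heads passed in, and (4) a freshly
      # initialized head with its own optimizer is appended for the next round.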
      for epoch in range(1, args.epochs + 1):
        for t_step in range(1, args.t_steps+1):
          print()
          print("Training topic predictor")
          train_topic_predictor(classifier_model, topic_decoder[-1], train_iter, topic_optim, topic_criterion, args, epoch, args.t_steps)

        if args.reset_classifier:
          for p in classifier_model.parameters():
            if not p.requires_grad:
              print ("OMG", p)
              p.requires_grad = True
            p.data.uniform_(-args.param_init, args.param_init)

        for c_step in range(1, args.c_steps+1):
          print()
          print("Training classifier")
          train_classifier(classifier_model, topic_decoder, train_iter, classify_optim, classify_criterion, topic_criterion, args, epoch, args.c_steps)
          loss = evaluate(classifier_model, topic_decoder, val_iter, classify_criterion, topic_criterion, args)
          #oodLoss = evaluate(classifier_model, outdomain_test_iter[0], classify_criterion, args, datatype="oodtest")

        #creating a new instance of a decoder
        attention_dim = args.hidden if not args.bi else 2*args.hidden
        if args.bottleneck_dim == 0:
          topic_decoder.append(nn.Sequential(nn.Dropout(args.topic_drop), nn.Linear(attention_dim, args.num_topics), nn.LogSoftmax(dim=-1)))
        else:
          topic_decoder.append(nn.Sequential(nn.Dropout(args.topic_drop), nn.Linear(args.bottleneck_dim, args.num_topics), nn.LogSoftmax(dim=-1)))

        #attaching a new optimizer to the new topic decoder
        topic_decoder[-1].to(device)
        topic_optim = Optim(args.optim, args.lr, args.clip)
        for p in topic_decoder[-1].parameters():
          if not p.requires_grad:
            print ("OMG", p)
            p.requires_grad = True
          p.data.uniform_(-args.param_init, args.param_init)
        topic_optim.set_parameters(topic_decoder[-1].parameters())

        if not best_valid_loss or loss < best_valid_loss:
          best_valid_loss = loss
          print ("Updating best model")
          best_model = copy.deepcopy(classifier_model)
          torch.save(best_model, args.save_dir+"/"+args.model_name+"_bestmodel")
        torch.save(classifier_model, args.save_dir+"/"+args.model_name+"_latestmodel")

    except KeyboardInterrupt:
      print("[Ctrl+C] Training stopped!")


  if args.finetune:
    best_valid_loss = None
    for c_step in range(1, args.c_steps+1):
      print()
      print("Fine-tuning classifier")
      train_classifier(classifier_model, None, train_iter, classify_optim, classify_criterion, None, args, c_step, args.c_steps)
      loss = evaluate(classifier_model, topic_decoder, val_iter, classify_criterion, topic_criterion, args)

      if not best_valid_loss or loss < best_valid_loss:
        best_valid_loss = loss
        print ("Updating best model")
        best_model = copy.deepcopy(classifier_model)
        torch.save(best_model, args.save_dir+"/"+args.model_name+"finetune_bestmodel")
      torch.save(classifier_model, args.save_dir+"/"+args.model_name+"finetune_latestmodel")


  if not args.load:
    trainloss = evaluate(best_model, topic_decoder, train_iter, classify_criterion, topic_criterion, args, datatype='train', itos=TEXT.vocab.itos, litos=LABEL.vocab.itos)
    valloss = evaluate(best_model, topic_decoder, val_iter, classify_criterion, topic_criterion, args, datatype='valid', itos=TEXT.vocab.itos, litos=LABEL.vocab.itos)
  loss = evaluate(best_model, topic_decoder, test_iter, classify_criterion, topic_criterion, args, datatype=os.path.basename(args.test_file).replace(".txt", "").replace(".tsv", ""), itos=TEXT.vocab.itos, litos=LABEL.vocab.itos)

  if args.ood_test_file:
    # reuses test_iter; only the reported datatype label is taken from the OOD file name
    loss = evaluate(best_model, topic_decoder, test_iter, classify_criterion, topic_criterion, args, datatype=os.path.basename(args.ood_test_file).replace(".txt", "").replace(".tsv", ""), itos=TEXT.vocab.itos, litos=LABEL.vocab.itos)
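
  # Aside (illustration only, kept in comments): the topic decoder heads end in
  # LogSoftmax because nn.KLDivLoss expects log-probabilities as input and a
  # probability distribution as target, e.g. (hypothetical sizes):
  #
  #   import torch, torch.nn as nn, torch.nn.functional as F
  #   head = nn.Sequential(nn.Linear(256, 50), nn.LogSoftmax(dim=-1))
  #   log_probs = head(torch.randn(8, 256))              # (batch, num_topics) log-probs
  #   target = F.softmax(torch.randn(8, 50), dim=-1)     # e.g. batch.topics, rows sum to 1
  #   loss = nn.KLDivLoss(reduction='sum')(log_probs, target)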