def main():
    args = make_parser().parse_args()
    print("[Model hyperparams]: {}".format(str(args)))

    cuda = torch.cuda.is_available() and args.cuda
    device = torch.device("cpu") if not cuda else torch.device("cuda:" +
                                                               str(args.gpu))
    seed_everything(seed=1337, cuda=cuda)
    vectors = None  #don't use pretrained vectors
    # vectors = load_pretrained_vectors(args.emsize)

    # Load dataset iterators
    iters, TEXT, LABEL, PROMPTS = dataset_map[args.data](
        args.batch_size,
        device=device,
        vectors=vectors,
        base_path=args.base_path)

    # Some datasets only provide train & test splits, so we use the test set as the validation set
    if len(iters) >= 4:
        train_iter = iters[0]
        val_iter = iters[1]
        test_iter = iters[2]
        outdomain_test_iter = list(iters[3:])
    elif len(iters) == 3:
        train_iter, val_iter, test_iter = iters
    else:
        train_iter, test_iter = iters
        val_iter = test_iter

    print("[Corpus]: train: {}, test: {}, vocab: {}, labels: {}".format(
        len(train_iter.dataset), len(test_iter.dataset), len(TEXT.vocab),
        len(LABEL.vocab)))

    if args.model == "CNN":
        args.embed_num = len(TEXT.vocab)
        args.nlabels = len(LABEL.vocab)
        args.nprompts = len(PROMPTS.vocab)
        args.kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
        args.embed_dim = args.emsize
        classifier_model = CNN_Text_GANLike(args)
        topic_decoder = [
            nn.Sequential(
                nn.Linear(
                    len(args.kernel_sizes) * args.kernel_num, args.nprompts))
        ]

    else:
        ntokens = len(TEXT.vocab)
        nlabels = len(LABEL.vocab)
        nprompts = len(PROMPTS.vocab)
        args.nlabels = nlabels  # hack to not clutter function arguments
        args.nprompts = nprompts
        embedding = nn.Embedding(ntokens, args.emsize, padding_idx=1)
        encoder = Encoder(args.emsize,
                          args.hidden,
                          nlayers=args.nlayers,
                          dropout=args.drop,
                          bidirectional=args.bi,
                          rnn_type=args.rnn_model)

        attention_dim = args.hidden if not args.bi else 2 * args.hidden
        attention = BahdanauAttention(attention_dim, attention_dim)

        if args.bottleneck_dim == 0:
            classifier_model = Classifier_GANLike(embedding, encoder,
                                                  attention, attention_dim,
                                                  nlabels)
            topic_decoder = [
                nn.Sequential(nn.Dropout(args.topic_drop),
                              nn.Linear(attention_dim, args.nprompts))
            ]
        else:
            classifier_model = Classifier_GANLike_bottleneck(
                embedding,
                encoder,
                attention,
                attention_dim,
                nlabels,
                bottleneck_dim=args.bottleneck_dim)
            topic_decoder = [
                nn.Sequential(nn.Dropout(args.topic_drop),
                              nn.Linear(args.bottleneck_dim, args.nprompts))
            ]

    classifier_model.to(device)
    topic_decoder[0].to(device)

    classify_criterion = nn.CrossEntropyLoss()
    topic_criterion = nn.CrossEntropyLoss()

    classify_optim = Optim(args.optim, args.lr, args.clip)
    topic_optim = Optim(args.optim, args.lr, args.clip)

    for p in classifier_model.parameters():
        if not p.requires_grad:
            print("OMG", p)
            p.requires_grad = True
        p.data.uniform_(-args.param_init, args.param_init)

    for p in topic_decoder[0].parameters():
        if not p.requires_grad:
            print("OMG", p)
            p.requires_grad = True
        p.data.uniform_(-args.param_init, args.param_init)

    classify_optim.set_parameters(classifier_model.parameters())
    topic_optim.set_parameters(topic_decoder[0].parameters())

    if args.load:
        if args.latest:
            best_model = torch.load(args.save_dir + "/" + args.model_name +
                                    "_latestmodel")
        else:
            best_model = torch.load(args.save_dir + "/" + args.model_name +
                                    "_bestmodel")
    else:
        try:
            best_valid_loss = None
            best_model = None

            #pretraining the classifier
            for epoch in range(1, args.pretrain_epochs + 1):
                pretrain_classifier(classifier_model, train_iter,
                                    classify_optim, classify_criterion, args,
                                    epoch)
                loss = evaluate(classifier_model, val_iter, classify_criterion,
                                args)
                # oodLoss = evaluate(classifier_model, outdomain_test_iter[0], classify_criterion, args, datatype="oodtest")

                if not best_valid_loss or loss < best_valid_loss:
                    best_valid_loss = loss
                    print("Updating best pretrained_model")
                    best_model = copy.deepcopy(classifier_model)
                    torch.save(
                        best_model, args.save_dir + "/" + args.model_name +
                        "_pretrained_bestmodel")
                torch.save(
                    classifier_model, args.save_dir + "/" + args.model_name +
                    "_pretrained_latestmodel")

            # classifier_model = best_model
            #alternating training like GANs
            for epoch in range(1, args.epochs + 1):
                for t_step in range(1, args.t_steps + 1):
                    train_topic_predictor(classifier_model, topic_decoder[-1],
                                          train_iter, topic_optim,
                                          topic_criterion, args, epoch,
                                          args.t_steps)

                if args.reset_classifier:
                    for p in classifier_model.parameters():
                        if not p.requires_grad:
                            print("OMG", p)
                            p.requires_grad = True
                        p.data.uniform_(-args.param_init, args.param_init)

                for c_step in range(1, args.c_steps + 1):
                    train_classifier(classifier_model, topic_decoder,
                                     train_iter, classify_optim,
                                     classify_criterion, topic_criterion, args,
                                     epoch, args.c_steps)
                    loss = evaluate(classifier_model, val_iter,
                                    classify_criterion, args)
                    # oodLoss = evaluate(classifier_model, outdomain_test_iter[0], classify_criterion, args, datatype="oodtest")

                #creating a new instance of a decoder
                if args.model == "CNN":
                    topic_decoder += [
                        nn.Sequential(
                            nn.Linear(
                                len(args.kernel_sizes) * args.kernel_num,
                                args.nprompts))
                    ]
                else:
                    attention_dim = args.hidden if not args.bi else 2 * args.hidden
                    if args.bottleneck_dim == 0:
                        topic_decoder.append(
                            nn.Sequential(
                                nn.Dropout(args.topic_drop),
                                nn.Linear(attention_dim, args.nprompts)))
                    else:
                        topic_decoder.append(
                            nn.Sequential(
                                nn.Dropout(args.topic_drop),
                                nn.Linear(args.bottleneck_dim, args.nprompts)))

                #attaching a new optimizer to the new topic decoder
                topic_decoder[-1].to(device)
                topic_optim = Optim(args.optim, args.lr, args.clip)
                for p in topic_decoder[-1].parameters():
                    if not p.requires_grad:
                        print("OMG", p)
                        p.requires_grad = True
                    p.data.uniform_(-args.param_init, args.param_init)
                topic_optim.set_parameters(topic_decoder[-1].parameters())

                if not best_valid_loss or loss < best_valid_loss:
                    best_valid_loss = loss
                    print("Updating best model")
                    best_model = copy.deepcopy(classifier_model)
                    torch.save(
                        best_model,
                        args.save_dir + "/" + args.model_name + "_bestmodel")
                torch.save(
                    classifier_model,
                    args.save_dir + "/" + args.model_name + "_latestmodel")

        except KeyboardInterrupt:
            print("[Ctrl+C] Training stopped!")

    # if not args.load:
    trainloss = evaluate(best_model,
                         train_iter,
                         classify_criterion,
                         args,
                         datatype='train',
                         writetopics=args.save_output_topics,
                         itos=TEXT.vocab.itos,
                         litos=LABEL.vocab.itos)
    valloss = evaluate(best_model,
                       val_iter,
                       classify_criterion,
                       args,
                       datatype='valid',
                       writetopics=args.save_output_topics,
                       itos=TEXT.vocab.itos,
                       litos=LABEL.vocab.itos)

    loss = evaluate(best_model,
                    test_iter,
                    classify_criterion,
                    args,
                    datatype='test',
                    writetopics=args.save_output_topics,
                    itos=TEXT.vocab.itos,
                    litos=LABEL.vocab.itos)
    if args.data == "AMAZON":
        oodnames = args.oodname.split(",")
        for oodname, oodtest_iter in zip(oodnames, outdomain_test_iter):
            oodLoss = evaluate(best_model,
                               oodtest_iter,
                               classify_criterion,
                               args,
                               datatype=oodname + "_bestmodel",
                               writetopics=args.save_output_topics)
            oodLoss = evaluate(classifier_model,
                               oodtest_iter,
                               classify_criterion,
                               args,
                               datatype=oodname + "_latest",
                               writetopics=args.save_output_topics)
    else:
        oodLoss = evaluate(best_model,
                           outdomain_test_iter[0],
                           classify_criterion,
                           args,
                           datatype="oodtest_bestmodel",
                           writetopics=args.save_output_topics,
                           itos=TEXT.vocab.itos,
                           litos=LABEL.vocab.itos)
        oodLoss = evaluate(classifier_model,
                           outdomain_test_iter[0],
                           classify_criterion,
                           args,
                           datatype="oodtest_latest",
                           writetopics=args.save_output_topics)
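
The `Optim` wrapper used throughout these examples is not shown on this page. Below is a minimal sketch, assuming a plain torch optimizer underneath, consistent with the calls made here (`Optim(method, lr, clip)`, `set_parameters`, `step`, and the `.optimizer` attribute); the real class in the source project may differ.

import torch
import torch.nn as nn


class Optim(object):
    """Sketch of an optimizer wrapper: builds a torch optimizer over a
    parameter list and clips gradients before each update."""

    def __init__(self, method, lr, max_grad_norm, lr_decay=1.0, start_decay_at=None):
        self.method = method
        self.lr = lr
        self.max_grad_norm = max_grad_norm
        self.lr_decay = lr_decay
        self.start_decay_at = start_decay_at

    def set_parameters(self, params):
        # keep only trainable parameters and build the underlying torch optimizer
        self.params = [p for p in params if p.requires_grad]
        if self.method == 'sgd':
            self.optimizer = torch.optim.SGD(self.params, lr=self.lr)
        elif self.method == 'adam':
            self.optimizer = torch.optim.Adam(self.params, lr=self.lr)
        else:
            raise ValueError("unsupported optim method: " + str(self.method))

    def step(self):
        # clip gradients in place, then take an optimizer step
        if self.max_grad_norm:
            nn.utils.clip_grad_norm_(self.params, self.max_grad_norm)
        self.optimizer.step()
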
Example #2
def main():

    print("Loading data from '%s'" % opt.data)

    dataset = torch.load(opt.data)

    dict_checkpoint = opt.train_from if opt.train_from else opt.train_from_state_dict
    if dict_checkpoint:
        print('Loading dicts from checkpoint at %s' % dict_checkpoint)
        checkpoint = torch.load(dict_checkpoint)
        dataset['dicts'] = checkpoint['dicts']

    trainData = Dataset(dataset['train']['src'], dataset['train']['tgt'],
                        opt.batch_size, opt.gpus)
    validData = Dataset(dataset['valid']['src'],
                        dataset['valid']['tgt'],
                        opt.batch_size,
                        opt.gpus,
                        volatile=True)

    dicts = dataset['dicts']
    print(' * vocabulary size. source = %d; target = %d' %
          (len(dicts["word2index"]['src']), len(dicts["word2index"]['tgt'])))
    print(' * number of training sentences. %d' % len(dataset['train']['src']))
    print(' * maximum batch size. %d' % opt.batch_size)

    print('Building model...')

    encoder = Encoder(opt, len(dicts["word2index"]['src']))
    decoder = Decoder(opt, len(dicts["word2index"]['tgt']))

    generator = nn.Sequential(
        nn.Linear(opt.hidden_size * 2, len(dicts["word2index"]['tgt'])),
        nn.LogSoftmax(dim=-1))

    model = NMTModel(encoder, decoder)

    if opt.train_from:
        print('Loading model from checkpoint at %s' % opt.train_from)
        chk_model = checkpoint['model']
        generator_state_dict = chk_model.generator.state_dict()
        model_state_dict = {
            k: v
            for k, v in chk_model.state_dict().items() if 'generator' not in k
        }
        model.load_state_dict(model_state_dict)
        generator.load_state_dict(generator_state_dict)
        opt.start_epoch = checkpoint['epoch'] + 1

    if opt.train_from_state_dict:
        print('Loading model from checkpoint at %s' %
              opt.train_from_state_dict)
        model.load_state_dict(checkpoint['model'])
        generator.load_state_dict(checkpoint['generator'])
        opt.start_epoch = checkpoint['epoch'] + 1

    if len(opt.gpus) >= 1:
        model.cuda()
        generator.cuda()
    else:
        model.cpu()
        generator.cpu()

    if len(opt.gpus) > 1:
        model = nn.DataParallel(model, device_ids=opt.gpus, dim=1)
        generator = nn.DataParallel(generator, device_ids=opt.gpus, dim=0)

    model.generator = generator

    if not opt.train_from_state_dict and not opt.train_from:
        for p in model.parameters():
            p.data.uniform_(-opt.param_init, opt.param_init)

        encoder.load_pretrained_vectors(opt)
        decoder.load_pretrained_vectors(opt)

        optim = Optim(opt.optim,
                      opt.learning_rate,
                      opt.max_grad_norm,
                      lr_decay=opt.learning_rate_decay,
                      start_decay_at=opt.start_decay_at)
    else:
        print('Loading optimizer from checkpoint:')
        optim = checkpoint['optim']
        print(optim)

    optim.set_parameters(model.parameters())

    if opt.train_from or opt.train_from_state_dict:
        optim.optimizer.load_state_dict(
            checkpoint['optim'].optimizer.state_dict())

    nParams = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % nParams)

    criterion = NMTCriterion(len(dicts["word2index"]['tgt']))

    trainModel(model, trainData, validData, dataset, optim, criterion)
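
For `-train_from` to work in the example above, the checkpoint file needs the keys that `main()` reads back ('dicts', 'model', 'generator', 'epoch', 'optim'). Here is a hedged sketch of how such a checkpoint could be written at the end of an epoch; the actual saving happens inside `trainModel`, which is not shown on this page, and `save_checkpoint` with its arguments is purely illustrative.

import torch


def save_checkpoint(model, generator, dicts, optim, epoch, path):
    # Illustrative only: stores full objects, matching the -train_from branch,
    # which reads checkpoint['model'] / checkpoint['generator'] as modules.
    # The -train_from_state_dict branch expects state_dicts under the same keys instead.
    checkpoint = {
        'model': model,
        'generator': generator,
        'dicts': dicts,
        'epoch': epoch,
        'optim': optim,
    }
    torch.save(checkpoint, path)
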
Example #3
def main(hparams: HParams):
    '''
    Set up and run training.
    '''
    if torch.cuda.is_available() and not hparams.gpus:
        warnings.warn(
            'WARNING: you have a CUDA device, so you should probably run with -gpus 0'
        )

    device = torch.device(hparams.gpus if torch.cuda.is_available() else 'cpu')

    # data setup
    print(f"Loading vocabulary...")
    text_preprocessor = TextPreprocessor.load(hparams.preprocessor_path)

    transform = transforms.Compose([
        transforms.Resize([hparams.img_size, hparams.img_size]),
        transforms.RandomCrop([hparams.crop_size, hparams.crop_size]),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # create dataloader
    print('Creating DataLoader...')
    normal_data_loader = get_image_caption_loader(
        hparams.img_dir,
        hparams.normal_caption_path,
        text_preprocessor,
        hparams.normal_batch_size,
        transform,
        shuffle=True,
        num_workers=hparams.num_workers,
    )

    style_data_loader = get_caption_loader(
        hparams.style_caption_path,
        text_preprocessor,
        batch_size=hparams.style_batch_size,
        shuffle=True,
        num_workers=hparams.num_workers,
    )

    if hparams.train_from:
        # loading checkpoint
        print('Loading checkpoint...')
        checkpoint = torch.load(hparams.train_from)

    # the optimizers are needed whether or not we resume from a checkpoint;
    # when resuming, their state is restored after set_parameters() below
    normal_opt = Optim(
        hparams.optimizer,
        hparams.normal_lr,
        hparams.max_grad_norm,
        hparams.lr_decay,
        hparams.start_decay_at,
    )
    style_opt = Optim(
        hparams.optimizer,
        hparams.style_lr,
        hparams.max_grad_norm,
        hparams.lr_decay,
        hparams.start_decay_at,
    )

    print('Building model...')
    encoder = EncoderCNN(hparams.hidden_dim)
    decoder = FactoredLSTM(hparams.embed_dim,
                           text_preprocessor.vocab_size,
                           hparams.hidden_dim,
                           hparams.style_dim,
                           hparams.num_layers,
                           hparams.random_init,
                           hparams.dropout_ratio,
                           train=True,
                           device=device)

    encoder = encoder.to(device)
    decoder = decoder.to(device)

    # loss and optimizer
    criterion = nn.CrossEntropyLoss(ignore_index=text_preprocessor.PAD_ID)
    normal_params = list(encoder.parameters()) + list(
        decoder.default_parameters())
    style_params = list(decoder.style_parameters())
    normal_opt.set_parameters(normal_params)
    style_opt.set_parameters(style_params)

    if hparams.train_from:
        encoder.load_state_dict(checkpoint['encoder'])
        decoder.load_state_dict(checkpoint['decoder'])
        normal_opt.load_state_dict(checkpoint['normal_opt'])
        style_opt.load_state_dict(checkpoint['style_opt'])

    # training loop
    print('Start training...')
    for epoch in range(hparams.num_epoch):

        # result
        sum_normal_loss, sum_style_loss, sum_normal_ppl, sum_style_ppl = 0, 0, 0, 0

        # normal caption
        for i, (images, in_captions, out_captions,
                lengths) in enumerate(normal_data_loader):
            images = images.to(device)
            in_captions = in_captions.to(device)
            out_captions = out_captions.contiguous().view(-1).to(device)

            # Forward, backward and optimize
            features = encoder(images)
            outputs = decoder(in_captions, features, mode='default')
            loss = criterion(outputs.view(-1, outputs.size(-1)), out_captions)
            encoder.zero_grad()
            decoder.zero_grad()
            loss.backward()
            normal_opt.step()

            # print log
            sum_normal_loss += loss.item()
            sum_normal_ppl += np.exp(loss.item())
            if i % hparams.normal_log_step == 0:
                print(
                    f'Epoch [{epoch}/{hparams.num_epoch}], Normal Step: [{i}/{len(normal_data_loader)}] '
                    f'Normal Loss: {loss.item():.4f}, Perplexity: {np.exp(loss.item()):5.4f}'
                )

        # style caption
        for i, (in_captions, out_captions,
                lengths) in enumerate(style_data_loader):
            in_captions = in_captions.to(device)
            out_captions = out_captions.contiguous().view(-1).to(device)

            # Forward, backward and optimize
            outputs = decoder(in_captions, None, mode='style')
            loss = criterion(outputs.view(-1, outputs.size(-1)), out_captions)

            decoder.zero_grad()
            loss.backward()
            style_opt.step()

            sum_style_loss += loss.item()
            sum_style_ppl += np.exp(loss.item())
            # print log
            if i % hparams.style_log_step == 0:
                print(
                    f'Epoch [{epoch}/{hparams.num_epoch}], Style Step: [{i}/{len(style_data_loader)}] '
                    f'Style Loss: {loss.item():.4f}, Perplexity: {np.exp(loss.item()):5.4f}'
                )

        model_params = {
            'encoder': encoder.state_dict(),
            'decoder': decoder.state_dict(),
            'epoch': epoch,
            'normal_opt': normal_opt.optimizer.state_dict(),
            'style_opt': style_opt.optimizer.state_dict(),
        }

        avg_normal_loss = sum_normal_loss / len(normal_data_loader)
        avg_style_loss = sum_style_loss / len(style_data_loader)
        avg_normal_ppl = sum_normal_ppl / len(normal_data_loader)
        avg_style_ppl = sum_style_ppl / len(style_data_loader)
        print(f'Epoch [{epoch}/{hparams.num_epoch}] statistics')
        print(
            f'Normal Loss: {avg_normal_loss:.4f} Normal ppl: {avg_normal_ppl:5.4f} '
            f'Style Loss: {avg_style_loss:.4f} Style ppl: {avg_style_ppl:5.4f}'
        )

        torch.save(
            model_params,
            f'{hparams.model_path}/n-loss_{avg_normal_loss:.4f}_s-loss_{avg_style_loss:.4f}_'
            f'n-ppl_{avg_normal_ppl:5.4f}_s-ppl_{avg_style_ppl:5.4f}_epoch_{epoch}.pt'
        )
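
As a quick sanity check of the preprocessing defined above, the same transform pipeline can be applied to a single image; with a crop size of 224 the encoder receives tensors of shape (3, 224, 224). The sizes and file name below are illustrative stand-ins, not the project's actual hparams.

from PIL import Image
import torchvision.transforms as transforms

img_size, crop_size = 256, 224  # stand-ins for hparams.img_size / hparams.crop_size
transform = transforms.Compose([
    transforms.Resize([img_size, img_size]),
    transforms.RandomCrop([crop_size, crop_size]),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

image = Image.open('example.jpg').convert('RGB')  # hypothetical input file
tensor = transform(image)
print(tensor.shape)  # torch.Size([3, 224, 224])
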
Example #4
def main():
  args = make_parser().parse_args()
  print("[Model hyperparams]: {}".format(str(args)))

  cuda = torch.cuda.is_available() and args.cuda
  device = torch.device("cpu") if not cuda else torch.device("cuda:"+str(args.gpu))
  seed_everything(seed=1337, cuda=cuda)
  vectors = None #don't use pretrained vectors
  # vectors = load_pretrained_vectors(args.emsize)

  # Load dataset iterators
  if args.data in ["RT_GENDER"]:
    if args.finetune:
      iters, TEXT, LABEL, INDEX = make_rt_gender(args.batch_size, base_path=args.base_path, train_file=args.train_file, valid_file=args.valid_file, test_file=args.test_file, device=device, vectors=vectors, topics=False)
    else:
      iters, TEXT, LABEL, TOPICS, INDEX = make_rt_gender(args.batch_size, base_path=args.base_path, train_file=args.train_file, valid_file=args.valid_file, test_file=args.test_file, device=device, vectors=vectors, topics=True)
    train_iter, val_iter, test_iter = iters
  else:
    assert False

  if not args.finetune:
    for batch in train_iter:
      args.num_topics = batch.topics.shape[1]
      break

  print("[Corpus]: train: {}, test: {}, vocab: {}, labels: {}".format(
            len(train_iter.dataset), len(test_iter.dataset), len(TEXT.vocab), len(LABEL.vocab)))

  if args.model == "CNN":
    args.embed_num = len(TEXT.vocab)
    args.nlabels = len(LABEL.vocab)
    args.kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
    args.embed_dim = args.emsize
    classifier_model = CNN_Text_GANLike(args)
    topic_decoder = nn.Sequential(nn.Linear(len(args.kernel_sizes)*args.kernel_num, args.num_topics), nn.LogSoftmax(dim=-1))

  else:
    ntokens, nlabels = len(TEXT.vocab), len(LABEL.vocab)
    args.nlabels = nlabels # hack to not clutter function arguments

    embedding = nn.Embedding(ntokens, args.emsize, padding_idx=1)
    encoder = Encoder(args.emsize, args.hidden, nlayers=args.nlayers,
                      dropout=args.drop, bidirectional=args.bi, rnn_type=args.rnn_model)

    attention_dim = args.hidden if not args.bi else 2*args.hidden
    attention = BahdanauAttention(attention_dim, attention_dim)

    if args.bottleneck_dim == 0:
      classifier_model = Classifier_GANLike(embedding, encoder, attention, attention_dim, nlabels)
      topic_decoder = [nn.Sequential(nn.Dropout(args.topic_drop), nn.Linear(attention_dim, args.num_topics), nn.LogSoftmax(dim=-1))]
    else:
      classifier_model = Classifier_GANLike_bottleneck(embedding, encoder, attention, attention_dim, nlabels, bottleneck_dim=args.bottleneck_dim)
      topic_decoder = [nn.Sequential(nn.Dropout(args.topic_drop), nn.Linear(args.bottleneck_dim, args.num_topics), nn.LogSoftmax(dim=-1))]


  classifier_model.to(device)
  topic_decoder[0].to(device)

  classify_criterion = nn.CrossEntropyLoss()
  topic_criterion = nn.KLDivLoss(reduction='sum')

  classify_optim = Optim(args.optim, args.lr, args.clip)
  topic_optim = Optim(args.optim, args.lr, args.clip)

  for p in classifier_model.parameters():
    if not p.requires_grad:
      print ("OMG", p)
      p.requires_grad = True
    p.data.uniform_(-args.param_init, args.param_init)

  for p in topic_decoder[0].parameters():
    if not p.requires_grad:
      print ("OMG", p)
      p.requires_grad = True
    p.data.uniform_(-args.param_init, args.param_init)

  classify_optim.set_parameters(classifier_model.parameters())
  topic_optim.set_parameters(topic_decoder[0].parameters())

  if args.load:
    if args.latest:
      best_model = torch.load(args.save_dir+"/"+args.model_name+"_latestmodel")
    else:
      best_model = torch.load(args.save_dir+"/"+args.model_name+"_bestmodel")
  else:
    try:
      best_valid_loss = None
      best_model = None

      #pretraining the classifier
      for epoch in range(1, args.pretrain_epochs+1):
        pretrain_classifier(classifier_model, train_iter, classify_optim, classify_criterion, args, epoch)
        loss = evaluate(classifier_model, topic_decoder, val_iter, classify_criterion, topic_criterion, args)
        #oodLoss = evaluate(classifier_model, outdomain_test_iter[0], classify_criterion, args, datatype="oodtest")

        if not best_valid_loss or loss < best_valid_loss:
          best_valid_loss = loss
          print ("Updating best pretrained_model")
          best_model = copy.deepcopy(classifier_model)
          torch.save(best_model, args.save_dir+"/"+args.model_name+"_pretrained_bestmodel")
        torch.save(classifier_model, args.save_dir+"/"+args.model_name+"_pretrained_latestmodel")

      print("Done pretraining")
      print()
      best_valid_loss = None
      best_model = None
      #alternating training like GANs
      for epoch in range(1, args.epochs + 1):
        for t_step in range(1, args.t_steps+1):
          print()
          print("Training topic predictor")
          train_topic_predictor(classifier_model, topic_decoder[-1], train_iter, topic_optim, topic_criterion, args, epoch, args.t_steps)

        if args.reset_classifier:
          for p in classifier_model.parameters():
            if not p.requires_grad:
              print ("OMG", p)
              p.requires_grad = True
            p.data.uniform_(-args.param_init, args.param_init)

        for c_step in range(1, args.c_steps+1):
          print()
          print("Training classifier")
          train_classifier(classifier_model, topic_decoder, train_iter, classify_optim, classify_criterion, topic_criterion, args, epoch, args.c_steps)
          loss = evaluate(classifier_model, topic_decoder, val_iter, classify_criterion, topic_criterion, args)
          #oodLoss = evaluate(classifier_model, outdomain_test_iter[0], classify_criterion, args, datatype="oodtest")

        #creating a new instance of a decoder
        attention_dim = args.hidden if not args.bi else 2*args.hidden
        if args.bottleneck_dim == 0:
          topic_decoder.append(nn.Sequential(nn.Dropout(args.topic_drop), nn.Linear(attention_dim, args.num_topics), nn.LogSoftmax(dim=-1)))
        else:
          topic_decoder.append(nn.Sequential(nn.Dropout(args.topic_drop), nn.Linear(args.bottleneck_dim, args.num_topics), nn.LogSoftmax(dim=-1)))

        #attaching a new optimizer to the new topic decoder
        topic_decoder[-1].to(device)
        topic_optim = Optim(args.optim, args.lr, args.clip)
        for p in topic_decoder[-1].parameters():
          if not p.requires_grad:
            print ("OMG", p)
            p.requires_grad = True
          p.data.uniform_(-args.param_init, args.param_init)
        topic_optim.set_parameters(topic_decoder[-1].parameters())

        if not best_valid_loss or loss < best_valid_loss:
          best_valid_loss = loss
          print ("Updating best model")
          best_model = copy.deepcopy(classifier_model)
          torch.save(best_model, args.save_dir+"/"+args.model_name+"_bestmodel")
        torch.save(classifier_model, args.save_dir+"/"+args.model_name+"_latestmodel")

    except KeyboardInterrupt:
      print("[Ctrl+C] Training stopped!")


  if args.finetune:
    best_valid_loss = None
    for c_step in range(1, args.c_steps+1):
      print()
      print("Fine-tuning classifier")
      train_classifier(classifier_model, None, train_iter, classify_optim, classify_criterion, None, args, c_step, args.c_steps)
      loss = evaluate(classifier_model, topic_decoder, val_iter, classify_criterion, topic_criterion, args)

      if not best_valid_loss or loss < best_valid_loss:
        best_valid_loss = loss
        print ("Updating best model")
        best_model = copy.deepcopy(classifier_model)
        torch.save(best_model, args.save_dir+"/"+args.model_name+"finetune_bestmodel")
      torch.save(classifier_model, args.save_dir+"/"+args.model_name+"finetune_latestmodel")


  if not args.load:
    trainloss = evaluate(best_model, topic_decoder, train_iter, classify_criterion, topic_criterion, args, datatype='train', itos=TEXT.vocab.itos, litos=LABEL.vocab.itos)
    valloss = evaluate(best_model, topic_decoder, val_iter, classify_criterion, topic_criterion, args, datatype='valid', itos=TEXT.vocab.itos, litos=LABEL.vocab.itos)
  loss = evaluate(best_model, topic_decoder, test_iter, classify_criterion, topic_criterion, args, datatype=os.path.basename(args.test_file).replace(".txt", "").replace(".tsv", ""), itos=TEXT.vocab.itos, litos=LABEL.vocab.itos)

  if args.ood_test_file:
    loss = evaluate(best_model, topic_decoder, test_iter, classify_criterion, topic_criterion, args, datatype=os.path.basename(args.ood_test_file).replace(".txt", "").replace(".tsv", ""), itos=TEXT.vocab.itos, litos=LABEL.vocab.itos)
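
`seed_everything(seed=1337, cuda=cuda)` is called at the top of the classifier examples but not defined on this page. A common implementation is assumed here: it seeds Python, NumPy and PyTorch, plus the CUDA RNGs when a GPU is in use; the project's own helper may differ.

import random

import numpy as np
import torch


def seed_everything(seed=1337, cuda=False):
    # Assumed helper: fix every RNG the training code touches, for reproducibility.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed_all(seed)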