def generate(generations, population, nn_param_choices, dataset):
    """Evolve a population of networks with the genetic algorithm.

    Every generation is passed to ``train_networks`` and its average
    accuracy is logged. All generations except the last are then replaced
    by ``optimizer.evolve``. Finally the five networks with the highest
    ``accuracy`` attribute are handed to ``print_networks``.

    Args:
        generations (int): Number of times to evolve the population
        population (int): Number of networks in each generation
        nn_param_choices (dict): Parameter choices for networks
        dataset (str): Dataset to use for training/evaluating
    """
    optimizer = Optimizer(nn_param_choices)
    networks = optimizer.create_population(population)

    # Generations are counted from 1 so the number can be logged directly.
    for generation in range(1, generations + 1):
        logging.info("***Doing generation %d of %d***" % (generation, generations))

        # Train this generation, then report how it did on average.
        train_networks(networks, dataset)
        mean_accuracy = get_average_accuracy(networks)
        logging.info("Generation average: %.2f%%" % (mean_accuracy * 100))
        logging.info('-' * 80)

        # The final generation is kept as trained instead of being evolved.
        if generation < generations:
            networks = optimizer.evolve(networks)

    # Rank the final population by accuracy, best first, and show the top 5.
    ranked = sorted(networks, key=lambda net: net.accuracy, reverse=True)
    print_networks(ranked[:5])
Пример #2
0
    def __init__(self, config):
        """Create the RegionGroundingModel and its optimizer, then list parameters.

        Args:
            config: settings object; this method reads ``cuda``, ``lr``,
                ``grad_norm_clipping``, ``coco_mode`` and ``pretrained``.
        """
        self.cfg = config

        model = RegionGroundingModel(config)
        self.net = model.cuda() if self.cfg.cuda else model

        # Only parameters that require gradients are handed to Adam.
        trainable = (p for p in self.net.parameters() if p.requires_grad)
        adam = optim.Adam(trainable, lr=self.cfg.lr)
        wrapped = Optimizer(adam, max_grad_norm=self.cfg.grad_norm_clipping)
        if self.cfg.coco_mode >= 0:
            # A step scheduler is attached only for non-negative coco_mode.
            wrapped.set_scheduler(
                optim.lr_scheduler.StepLR(wrapped.optimizer,
                                          step_size=75,
                                          gamma=0.1))
        self.optimizer = wrapped
        self.epoch = 0

        if self.cfg.pretrained is not None:
            self.load_pretrained_net(self.cfg.pretrained)

        # First pass prints every parameter, second pass only trainable ones.
        sections = (('All parameters', False), ('Trainable parameters', True))
        for title, trainable_only in sections:
            print('-------------------')
            print(title)
            for name, param in self.net.named_parameters():
                if not trainable_only or param.requires_grad:
                    print(name, param.size())
Пример #3
0
def main_worker(rank, args):
    """Run the train / validate / test schedule for one worker.

    Args:
        rank: index of this worker; stored on ``args.rank`` before ``setup``.
        args: run configuration namespace. This function reads ``amp``,
            ``stay``, ``demo``, ``startEpoch``, ``endEpoch``, ``do_train``,
            ``do_validate``, ``validate_every``, ``do_test``, ``test_every``,
            ``rank`` and ``launched`` from the object returned by ``setup``.

    Raises:
        SystemExit: after the interactive session (``args.stay``) or the
            demo evaluation (``args.demo``) finishes.
    """
    # Local import: sys.exit() replaces the site-provided exit() helper,
    # which is meant for the interactive shell and is not always defined.
    import sys

    args.rank = rank
    args = setup(args)

    loaders = Data(args).get_loader()
    model = Model(args)
    optimizer = Optimizer(args, model)
    if args.amp:
        model = optimizer.set_amp(model)
    model.parallelize()

    criterion = Loss(args, model=model, optimizer=optimizer)

    trainer = Trainer(args, model, criterion, optimizer, loaders)

    if args.stay:
        # Drop into an interactive console with the objects built above.
        interact(local=locals())
        sys.exit()

    if args.demo:
        trainer.evaluate(epoch=args.startEpoch, mode='demo')
        sys.exit()

    # Epochs before startEpoch are not trained; fill_evaluation is called on
    # the usual schedule (presumably to restore earlier results -- confirm).
    for epoch in range(1, args.startEpoch):
        if args.do_validate and epoch % args.validate_every == 0:
            trainer.fill_evaluation(epoch, 'val')
        if args.do_test and epoch % args.test_every == 0:
            trainer.fill_evaluation(epoch, 'test')

    for epoch in range(args.startEpoch, args.endEpoch + 1):
        if args.do_train:
            trainer.train(epoch)

        if args.do_validate and epoch % args.validate_every == 0:
            # Load this epoch's state when the trainer is not already on it.
            if trainer.epoch != epoch:
                trainer.load(epoch)
            trainer.validate(epoch)

        if args.do_test and epoch % args.test_every == 0:
            if trainer.epoch != epoch:
                trainer.load(epoch)
            trainer.test(epoch)

        if args.rank == 0 or not args.launched:
            print('')

    trainer.imsaver.join_background()

    cleanup(args)
Пример #4
0
    def __init__(self, dset, conf, save=False):
        """Train the binary model with periodic validation and early stopping.

        Args:
            dset: dataset object providing ``build_batches``, ``vocab`` and
                ``char_vocab``.
            conf (dict): configuration with "model", "optim" and "train"
                sections; "train" supplies "max_epochs", "val_steps" and
                "patience".
            save (bool): when true, ``model.save`` is called every time the
                validation F1 improves.
        """
        train_batches, val_batches, test_batches = dset.build_batches("relevant")

        # Model, optimizer, loss and learning-rate scheduler.
        model = self.build_model(conf["model"], dset.vocab, dset.char_vocab)
        opt = Optimizer(model.parameters(), conf["optim"])
        criterion = BCE()
        lr_sch = LR_scheduler(opt.opt, conf["optim"])

        # Best metrics seen so far; `misses` counts checks without improvement.
        self.best = {"val": {"f1": 0}, "test": {}}
        step = 0
        misses = 0

        for ep in range(conf["train"]["max_epochs"]):
            print("\n\tEpoch %d" % ep)
            for batch in train_batches:
                model.train()
                step += 1

                # Forward pass, loss, then one optimizer update.
                preds, gold, _ = self.fw_pass(model, batch)
                opt.train_op(criterion(preds, gold))

                # Validate every `val_steps` training steps.
                if step % conf["train"]["val_steps"] == 0:
                    model.eval()
                    val_metrics = utils.bin_fw_eval(model, self.fw_pass,
                                                    val_batches, step)
                    improved = val_metrics["f1"] > self.best["val"]["f1"]
                    if not improved:
                        # Stop once the counter has already reached
                        # `patience`; otherwise record this miss.
                        if misses == conf["train"]["patience"]:
                            return
                        misses += 1
                    else:
                        misses = 0
                        # The test set is evaluated only on improvement.
                        test_metrics = utils.bin_fw_eval(
                            model, self.fw_pass, test_batches, step)
                        self.best = {"val": val_metrics, "test": test_metrics}
                        if save:
                            model.save(step, conf, self.best, opt, lr_sch,
                                       "bin")

                # The scheduler is stepped once per batch.
                lr_sch.step()
Пример #5
0
    def train(self,
              model,
              data,
              num_epochs=5,
              resume=False,
              dev_data=None,
              optimizer=None,
              teacher_forcing_ratio=0):
        """Train ``model`` on ``data``, optionally resuming from a checkpoint.

        Args:
            model (seq2seq.models): model to train; with `resume=True` it is
               replaced by the model stored in the latest checkpoint.
            data (seq2seq.dataset.dataset.Dataset): dataset to train on
            num_epochs (int, optional): number of epochs to run (default 5)
            resume (bool, optional): continue from the latest checkpoint in
               ``self.expt_dir`` (default False)
            dev_data (seq2seq.dataset.dataset.Dataset, optional): dev Dataset
               (default None)
            optimizer (seq2seq.optim.Optimizer, optional): optimizer to use
               when not resuming (default: Adam wrapped with max_grad_norm=5);
               not used when `resume=True`.
            teacher_forcing_ratio (float, optional): teacher forcing ratio
               (default 0)
        Returns:
            model (seq2seq.models): the model that was trained.
        """
        if not resume:
            start_epoch, step = 1, 0
            if optimizer is None:
                optimizer = Optimizer(optim.Adam(model.parameters()),
                                      max_grad_norm=5)
            self.optimizer = optimizer
        else:
            checkpoint = Checkpoint.load(
                Checkpoint.get_latest_checkpoint(self.expt_dir))
            model = checkpoint.model
            self.optimizer = checkpoint.optimizer

            # Workaround: rebuild the inner optimizer around the restored
            # model's parameters, reusing the first param group's settings.
            saved_optim = self.optimizer.optimizer
            settings = saved_optim.param_groups[0]
            for key in ('params', 'initial_lr'):
                settings.pop(key, None)
            self.optimizer.optimizer = saved_optim.__class__(
                model.parameters(), **settings)

            start_epoch, step = checkpoint.epoch, checkpoint.step

        self.logger.info("Optimizer: %s, Scheduler: %s" %
                         (self.optimizer.optimizer, self.optimizer.scheduler))

        self._train_epoches(data,
                            model,
                            num_epochs,
                            start_epoch,
                            step,
                            dev_data=dev_data,
                            teacher_forcing_ratio=teacher_forcing_ratio)
        return model
Пример #6
0
    def __init__(self, db):
        """Create the PuzzleModel, place it on the GPU(s) and build its optimizer.

        Args:
            db: object exposing ``cfg``; it is also passed to ``PuzzleModel``.
                From ``cfg`` this method reads ``cuda``, ``parallel``,
                ``finetune_lr``, ``lr``, ``grad_norm_clipping`` and
                ``pretrained``.
        """
        self.db = db
        self.cfg = db.cfg
        self.net = PuzzleModel(db)
        if self.cfg.cuda:
            if self.cfg.parallel and torch.cuda.device_count() > 1:
                print("Let's use", torch.cuda.device_count(), "GPUs!")
                self.net = nn.DataParallel(self.net)
            self.net = self.net.cuda()

        # Unwrap only when the model really was wrapped above. With
        # cuda+parallel on a single-GPU host no DataParallel is created, so
        # reading `.module` unconditionally would raise AttributeError.
        if isinstance(self.net, nn.DataParallel):
            net = self.net.module
        else:
            net = self.net

        image_encoder_trainable_paras = \
            filter(lambda p: p.requires_grad, net.image_encoder.parameters())
        # The text embedding gets its own learning rate; every other group
        # uses the default `lr` given to Adam.
        raw_optimizer = optim.Adam([
                {'params': image_encoder_trainable_paras},
                {'params': net.text_encoder.embedding.parameters(), 'lr': self.cfg.finetune_lr},
                {'params': net.text_encoder.rnn.parameters()},
                {'params': net.what_decoder.parameters()},
                {'params': net.where_decoder.parameters()},
                {'params': net.shape_encoder.parameters()},
            ], lr=self.cfg.lr)
        optimizer = Optimizer(raw_optimizer, max_grad_norm=self.cfg.grad_norm_clipping)
        # No scheduler is attached here; alternatives that were tried:
        # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer.optimizer, factor=0.8, patience=3)
        # scheduler = optim.lr_scheduler.StepLR(optimizer.optimizer, step_size=3, gamma=0.8)
        # optimizer.set_scheduler(scheduler)
        self.optimizer = optimizer
        self.epoch = 0

        if self.cfg.pretrained is not None:
            self.load_pretrained_net(self.cfg.pretrained)
Пример #7
0
def train(net: NeuralNet,
          inputs: Tensor,
          targets: Tensor,
          num_epochs: int = 5000,
          iterator: DataIterator = BatchIterator(),
          loss: Loss = MSE(),
          optimizer: Optimizer = SGD()
          ) -> None:
    """Run ``num_epochs`` passes over the data and print each epoch's loss.

    For every batch produced by ``iterator(inputs, targets)`` the function
    runs a forward pass, adds the batch loss to the epoch total, feeds the
    loss gradient to ``net.backward`` and calls ``optimizer.step(net)``.

    Note:
        The default ``iterator``, ``loss`` and ``optimizer`` objects are
        created once, when the function is defined, and are shared by every
        call that relies on the defaults.
    """
    for epoch in range(num_epochs):
        running_loss = 0.0
        for minibatch in iterator(inputs, targets):
            outputs = net.forward(minibatch.inputs)
            running_loss += loss.loss(outputs, minibatch.targets)
            # Push the loss gradient back through the net, then take a step.
            net.backward(loss.grad(outputs, minibatch.targets))
            optimizer.step(net)
        print(epoch, running_loss)
Пример #8
0
def get_model(model_file_path=None, eval=False):
    """Build a Summarizer and its optimizer, optionally restoring a checkpoint.

    Args:
        model_file_path: path of a checkpoint readable by ``torch.load`` that
            holds the keys 'step', 'loss', 'model' and 'optimizer'; ``None``
            starts from a fresh model.
        eval: when true, the optimizer state is not restored. (The name
            shadows the builtin but is kept for existing callers.)

    Returns:
        tuple: ``(model, optimizer, step, loss)``. Without a checkpoint,
        ``step`` is 1 and ``loss`` is 0.
    """
    model = Summarizer()
    learning_rate = config.lr_coverage if config.cov else config.lr
    optimizer = Optimizer(config.optim, learning_rate,
                          acc=config.adagrad_init_acc, max_grad_norm=config.max_grad_norm)
    optimizer.set_parameters(model.parameters())

    step, loss = 1, 0
    if model_file_path is not None:
        checkpoint = torch.load(model_file_path)
        step = checkpoint['step']
        loss = checkpoint['loss']

        # strict=False tolerates missing or unexpected keys in the state dict.
        model.load_state_dict(dict(checkpoint['model']), strict=False)

        # Optimizer state is restored only without coverage and outside eval.
        if not config.cov and not eval:
            optimizer.optim.load_state_dict(checkpoint['optimizer'])
            if config.cuda:
                # Move each restored per-parameter state tensor to the GPU.
                # The loop must walk the state's own entries, not the
                # top-level checkpoint dict.
                for state in optimizer.optim.state.values():
                    for k, v in state.items():
                        if torch.is_tensor(v):
                            state[k] = v.cuda()
    if config.cuda:
        model = model.cuda()
        # NOTE(review): set_parameters is called again after .cuda(); if it
        # rebuilds the inner optimizer, the state restored above may be
        # discarded -- confirm against Optimizer.set_parameters.
        optimizer.set_parameters(model.parameters())
    return model, optimizer, step, loss
Пример #9
0
def train(sourceVocabClass, targetVocabClass):
    """Train the Equilid model on every train/dev file pair in the data dir.

    The model, loss and both fields come from ``create_model``. One trainer
    and one optimizer are created and reused for every pair returned by
    ``get_file_paths``; the model returned by each ``train`` call is fed
    into the next one.
    """
    # Make sure the output directory exists before anything is written.
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    max_len = 1000
    seq2seqModel, loss, srcField, tgtField = create_model(
        sourceVocabClass, targetVocabClass)

    # Log the size and contents of both vocabularies, in both directions.
    vocab_reports = (
        ("char itos length:{} desc:{}", srcField.vocab.itos),
        ("char stoi length:{} desc:{}", srcField.vocab.stoi),
        ("lang itos length:{} desc:{}", tgtField.vocab.itos),
        ("lang stoi length:{} desc:{}", tgtField.vocab.stoi),
    )
    for template, table in vocab_reports:
        logger.debug(template.format(len(table), table))

    # Train/dev path pairs found in the data directory.
    file_pairs = get_file_paths(FLAGS.data_dir, 'train:dev')

    if torch.cuda.is_available():
        loss.cuda()

    print("Training model")
    trainer = SupervisedTrainer(loss=loss,
                                batch_size=int(FLAGS.batch_size),
                                checkpoint_every=FLAGS.checkpoint_interval,
                                print_every=20,
                                expt_dir=FLAGS.expt_dir)
    optimizer = Optimizer(
        torch.optim.Adam(seq2seqModel.parameters(),
                         lr=float(FLAGS.learning_rate)),
        max_grad_norm=FLAGS.max_gradient_norm)

    for train_path, dev_path in file_pairs:
        train_dataset = load_tabular_dataset(train_path, srcField, tgtField,
                                             max_len)
        dev_dataset = load_tabular_dataset(dev_path, srcField, tgtField,
                                           max_len)
        logger.debug("Using Dataset files train:{} dev:{}".format(
            train_path, dev_path))
        seq2seqModel = trainer.train(seq2seqModel,
                                     train_dataset,
                                     num_epochs=int(FLAGS.num_epochs),
                                     dev_data=dev_dataset,
                                     optimizer=optimizer,
                                     teacher_forcing_ratio=1,
                                     resume=FLAGS.resume)
    print("training completed!")
Пример #10
0
    def __init__(self, config):
        """Create the SynthesisModel, place it on the GPU(s) and build its optimizer.

        Args:
            config: settings object; this method reads ``cuda``, ``parallel``,
                ``lr`` and ``pretrained``.
        """
        self.cfg = config
        self.net = SynthesisModel(config)
        if self.cfg.cuda:
            if self.cfg.parallel and torch.cuda.device_count() > 1:
                print("Let's use", torch.cuda.device_count(), "GPUs!")
                self.net = nn.DataParallel(self.net)
            self.net = self.net.cuda()

        # Unwrap only when the model really was wrapped above. With
        # cuda+parallel on a single-GPU host no DataParallel is created, so
        # reading `.module` unconditionally would raise AttributeError.
        if isinstance(self.net, nn.DataParallel):
            net = self.net.module
        else:
            net = self.net

        # Encoder and decoder are separate param groups sharing one lr.
        raw_optimizer = optim.Adam(
            [
                {'params': net.encoder.parameters()},
                {'params': net.decoder.parameters()},
            ],
            lr=self.cfg.lr)
        optimizer = Optimizer(raw_optimizer)
        self.optimizer = optimizer
        self.epoch = 0
        if self.cfg.pretrained is not None:
            self.load_pretrained_net(self.cfg.pretrained)
Пример #11
0
# Assemble the encoder/decoder pair and move it to the GPU when one is present.
seq2seq_m = Seq2seq(encoder, decoder)
if torch.cuda.is_available():
    seq2seq_m.cuda()

# Re-initialize every parameter in place with uniform_(-0.08, 0.08).
for param in seq2seq_m.parameters():
    param.data.uniform_(-0.08, 0.08)

# Trainer settings: checkpoint_every=50 and print_every=10 are fixed here;
# loss, batch_size and expt_dir are defined earlier in the script.
t = SupervisedTrainer(loss=loss,
                      batch_size=batch_size,
                      checkpoint_every=50,
                      print_every=10,
                      expt_dir=expt_dir)

# Adam wrapped in the project's Optimizer; no scheduler is attached.
optimizer = Optimizer(
    torch.optim.Adam(seq2seq_m.parameters(), lr=0.001, betas=(0.9, 0.999)))
# scheduler = StepLR(optimizer.optimizer, 1)
# optimizer.set_scheduler(scheduler)

# Run training; the trainer's return value replaces seq2seq_m.
seq2seq_m = t.train(seq2seq_m,
                    train,
                    num_epochs=num_epochs,
                    dev_data=dev,
                    optimizer=optimizer,
                    teacher_forcing_ratio=0.5,
                    resume=resume)

# Whole seconds elapsed since start_time, printed as HH:MM:SS.
e = int(time.time() - start_time)
print('ELAPSED TIME TRAINING ~> {:02d}:{:02d}:{:02d}'.format(
    e // 3600, (e % 3600 // 60), e % 60))
    def __init__(self, dset, conf, save=False):
        """Train the tagging model with periodic validation and early stopping.

        Args:
            dset: dataset object providing ``build_batches``, ``vocab``,
                ``char_vocab`` and ``tag_vocab``.
            conf (dict): configuration with "model", "optim" and "train"
                sections; "model" supplies "use_crf" and "train" supplies
                "max_epochs", "val_steps" and "patience".
            save (bool): when true, ``model.save`` is called every time the
                validation F1 improves.
        """
        train_batches, val_batches, test_batches = dset.build_batches("tok_tags")

        model = self.build_model(conf["model"], dset.vocab, dset.char_vocab,
                                 dset.tag_vocab)
        opt = Optimizer(model.parameters(), conf["optim"])
        # The loss depends on whether the model uses a CRF.
        criterion = (CustomLoss(model.crf) if conf["model"]["use_crf"]
                     else WeightedCEL(dset.tag_vocab))
        lr_sch = LR_scheduler(opt.opt, conf["optim"])

        # Best metrics seen so far; `misses` counts checks without improvement.
        self.best = {"val": {"f1": 0}, "test": {}}
        step = 0
        misses = 0

        # Tag ids passed to the evaluation helper to be ignored in metrics.
        ign_tok = [dset.tag_vocab[tag] for tag in ("<p>", "<s>", "</s>")]

        for ep in range(conf["train"]["max_epochs"]):
            print("\n\tEpoch %d" % ep)
            for batch in train_batches:
                model.train()
                step += 1

                # Forward pass, masked loss, then one optimizer update.
                preds, gold, mask = self.fw_pass(model, batch)
                opt.train_op(criterion(preds, gold, mask))

                # Validate every `val_steps` training steps.
                if step % conf["train"]["val_steps"] == 0:
                    model.eval()
                    val_metrics = utils.ner_fw_eval(model, self.fw_pass,
                                                    val_batches, step, ign_tok)
                    improved = val_metrics["f1"] > self.best["val"]["f1"]
                    if not improved:
                        # Stop once the counter has already reached
                        # `patience`; otherwise record this miss.
                        if misses == conf["train"]["patience"]:
                            return
                        misses += 1
                    else:
                        misses = 0
                        # The test set is evaluated only on improvement.
                        test_metrics = utils.ner_fw_eval(
                            model, self.fw_pass, test_batches, step, ign_tok)
                        self.best = {"val": val_metrics, "test": test_metrics}
                        if save:
                            model.save(step, conf, self.best, opt, lr_sch,
                                       "ner")

                # The scheduler is stepped once per batch.
                lr_sch.step()
Пример #13
0
    def train(self, train_db, val_db, test_db):
        """Run the full training loop with per-epoch validation, sampling and logging.

        Args:
            train_db: passed to ``self.train_epoch`` each epoch.
            val_db: passed to ``self.validate_epoch`` each epoch.
            test_db: passed to ``self.sample`` each epoch.

        Each epoch trains, validates, samples, updates the optimizer wrapper
        with the mean validation loss, writes one row of statistics through
        ``logz`` and saves a checkpoint.
        """
        # ---- Optimizer --------------------------------------------------
        # The text embedding and the trainable image-encoder parameters use
        # finetune_lr; the remaining groups use the default lr.
        trainable_image_params = filter(
            lambda p: p.requires_grad, self.net.image_encoder.parameters())
        param_groups = [
            {'params': self.net.text_encoder.embedding.parameters(), 'lr': self.cfg.finetune_lr},
            {'params': trainable_image_params, 'lr': self.cfg.finetune_lr},
            {'params': self.net.text_encoder.rnn.parameters()},
            {'params': self.net.what_decoder.parameters()},
            {'params': self.net.where_decoder.parameters()},
        ]
        optimizer = Optimizer(optim.Adam(param_groups, lr=self.cfg.lr),
                              max_grad_norm=self.cfg.grad_norm_clipping)
        optimizer.set_scheduler(
            optim.lr_scheduler.StepLR(optimizer.optimizer, step_size=3, gamma=0.8))

        # ---- Log setup --------------------------------------------------
        logz.configure_output_dir(self.cfg.model_dir)
        logz.save_config(self.cfg)

        # Statistics logged for every series, in this order.
        summaries = (("Average", np.mean), ("Std", np.std),
                     ("Max", np.max), ("Min", np.min))
        # Column k of val_accu is logged under the k-th label below.
        accu_columns = ("Obj", "Pose", "Expr", "Coord", "Scale", "Flip")

        def log_summary(split, suffix, values):
            # Writes <split><stat><suffix> for each statistic, in order.
            for stat, reduce_fn in summaries:
                logz.log_tabular(split + stat + suffix, reduce_fn(values))

        # ---- Main loop --------------------------------------------------
        began = time()
        for epoch in range(self.cfg.n_epochs):
            # Training
            torch.cuda.empty_cache()
            train_pred_loss, _attn_loss, _eos_loss, train_accu = \
                self.train_epoch(train_db, optimizer, epoch)

            # Validation
            torch.cuda.empty_cache()
            val_loss, val_accu, val_infos = self.validate_epoch(val_db)

            # Sampling
            torch.cuda.empty_cache()
            self.sample(epoch, test_db, self.cfg.n_samples)
            torch.cuda.empty_cache()

            # Update the optimizer wrapper with the mean validation loss.
            optimizer.update(np.mean(val_loss), epoch)

            # Logging
            logz.log_tabular("Time", time() - began)
            logz.log_tabular("Iteration", epoch)

            log_summary("Train", "Error", train_pred_loss)
            log_summary("Train", "Accu", train_accu)
            log_summary("Val", "Error", val_loss)
            log_summary("Val", "Accu", val_accu)
            for column, label in enumerate(accu_columns):
                log_summary("Val", label + "Accu", val_accu[:, column])

            logz.log_tabular("ValUnigramF3", np.mean(val_infos.unigram_F3()))
            logz.log_tabular("ValBigramF3", np.mean(val_infos.bigram_F3()))
            logz.log_tabular("ValUnigramP", np.mean(val_infos.unigram_P()))
            logz.log_tabular("ValUnigramR", np.mean(val_infos.unigram_R()))
            logz.log_tabular("ValBigramP", val_infos.mean_bigram_P())
            logz.log_tabular("ValBigramR", val_infos.mean_bigram_R())

            logz.log_tabular("ValUnigramPose", np.mean(val_infos.pose()))
            logz.log_tabular("ValUnigramExpr", np.mean(val_infos.expr()))
            logz.log_tabular("ValUnigramScale", np.mean(val_infos.scale()))
            logz.log_tabular("ValUnigramFlip", np.mean(val_infos.flip()))
            logz.log_tabular("ValUnigramSim", np.mean(val_infos.unigram_coord()))
            logz.log_tabular("ValBigramSim", val_infos.mean_bigram_coord())

            logz.dump_tabular()

            # Checkpoint, tagged with mean validation loss and accuracy.
            self.save_checkpoint(epoch, [np.mean(val_loss), np.mean(val_accu)])
            torch.cuda.empty_cache()
Пример #14
0
                             dropout_p=0.2,
                             use_attention=True,
                             bidirectional=bidirectional,
                             eos_id=tgt.eos_id,
                             sos_id=tgt.sos_id)
        seq2seq = Seq2seq(encoder, decoder)
        if torch.cuda.is_available():
            seq2seq.cuda()

        for param in seq2seq.parameters():
            param.data.uniform_(-0.08, 0.08)

            # Optimizer and learning rate scheduler can be customized by
            # explicitly constructing the objects and pass to the trainer.
            #
            optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters()),
                                  max_grad_norm=5)
            scheduler = StepLR(optimizer.optimizer, 1)
            optimizer.set_scheduler(scheduler)

    # train
    t = SupervisedTrainer(loss=loss,
                          batch_size=32,
                          checkpoint_every=50,
                          print_every=10,
                          expt_dir=opt.expt_dir)

    seq2seq = t.train(seq2seq,
                      train,
                      num_epochs=20000,
                      dev_data=dev,
                      optimizer=optimizer,
Пример #15
0
def train(args):
    """Run model training.

    Builds or restores the model, creates the train/valid loaders and the
    predictor, evaluator, saver and optimizer helpers from the nested
    namespaces on ``args``, then trains until
    ``optimizer.is_finished_training()`` returns true. Validation, metric
    logging, checkpointing and learning-rate scheduling all happen inside
    the iteration loop whenever the evaluation condition fires.

    Args:
        args: Parsed arguments. Reads ``model_args``, ``logger_args``,
            ``optim_args``, ``data_args``, ``transform_args``, ``gpu_ids``
            and ``device``. ``optim_args.start_epoch`` is overwritten when
            a checkpoint is loaded.
    """

    print("Start Training ...")

    # Get nested namespaces.
    model_args = args.model_args
    logger_args = args.logger_args
    optim_args = args.optim_args
    data_args = args.data_args
    transform_args = args.transform_args

    # Get logger.
    print('Getting logger... log to path: {}'.format(logger_args.log_path))
    logger = Logger(logger_args.log_path, logger_args.save_dir)

    # An empty value and the literal string 'None' both mean "no checkpoint".
    # Decide this once so that model loading and optimizer restoring agree;
    # otherwise ckpt_path == 'None' would build a fresh model and then try
    # to restore optimizer state from a path called 'None'.
    has_ckpt = bool(model_args.ckpt_path) and model_args.ckpt_path != 'None'

    # For conaug, point to the MOCO pretrained weights.
    if has_ckpt:
        print("pretrained checkpoint specified : {}".format(
            model_args.ckpt_path))
        # CL-specified args are used to load the model, rather than the
        # ones saved to args.json.
        model_args.pretrained = False
        ckpt_path = model_args.ckpt_path
        model, ckpt_info = ModelSaver.load_model(ckpt_path=ckpt_path,
                                                 gpu_ids=args.gpu_ids,
                                                 model_args=model_args,
                                                 is_training=True)

        # Resume the epoch counter from the checkpoint, except for MOCO
        # weights, where training starts again at epoch 1.
        if not model_args.moco:
            optim_args.start_epoch = ckpt_info['epoch'] + 1
        else:
            optim_args.start_epoch = 1
    else:
        print(
            'Starting without pretrained training checkpoint, random initialization.'
        )
        # If no ckpt_path is provided, instantiate a new randomly
        # initialized model.
        model_fn = models.__dict__[model_args.model]
        if data_args.custom_tasks is not None:
            tasks = NamedTasks[data_args.custom_tasks]
        else:
            tasks = model_args.__dict__[TASKS]  # TASKS = "tasks"
        print("Tasks: {}".format(tasks))
        model = model_fn(tasks, model_args)
        model = nn.DataParallel(model, args.gpu_ids)

    # Put model on gpu or cpu and put into training mode.
    model = model.to(args.device)
    model.train()

    print("========= MODEL ==========")
    print(model)

    # Get train and valid loader objects.
    train_loader = get_loader(phase="train",
                              data_args=data_args,
                              transform_args=transform_args,
                              is_training=True,
                              return_info_dict=False,
                              logger=logger)
    valid_loader = get_loader(phase="valid",
                              data_args=data_args,
                              transform_args=transform_args,
                              is_training=False,
                              return_info_dict=False,
                              logger=logger)

    # Instantiate the predictor class for obtaining model predictions.
    predictor = Predictor(model, args.device)
    # Instantiate the evaluator class for evaluating models.
    evaluator = Evaluator(logger)
    # Get the set of tasks which will be used for saving models
    # and annealing learning rate.
    eval_tasks = EVAL_METRIC2TASKS[optim_args.metric_name]

    # Instantiate the saver class for saving model checkpoints.
    saver = ModelSaver(save_dir=logger_args.save_dir,
                       iters_per_save=logger_args.iters_per_save,
                       max_ckpts=logger_args.max_ckpts,
                       metric_name=optim_args.metric_name,
                       maximize_metric=optim_args.maximize_metric,
                       keep_topk=logger_args.keep_topk)

    # TODO: JBY: handle threshold for fine tuning
    # 'full' fine-tunes all layers and needs no changes; any other value is
    # a comma-separated list passed to the helper below.
    if model_args.fine_tuning != 'full':
        # Freeze other layers.
        models.PretrainedModel.set_require_grad_for_fine_tuning(
            model, model_args.fine_tuning.split(','))

    # Instantiate the optimizer class for guiding model training.
    optimizer = Optimizer(parameters=model.parameters(),
                          optim_args=optim_args,
                          batch_size=data_args.batch_size,
                          iters_per_print=logger_args.iters_per_print,
                          iters_per_visual=logger_args.iters_per_visual,
                          iters_per_eval=logger_args.iters_per_eval,
                          dataset_len=len(train_loader.dataset),
                          logger=logger)

    if has_ckpt and not model_args.moco:
        # Load the same optimizer as used in the original training.
        optimizer.load_optimizer(ckpt_path=model_args.ckpt_path,
                                 gpu_ids=args.gpu_ids)

    loss_fn = evaluator.get_loss_fn(
        loss_fn_name=optim_args.loss_fn,
        model_uncertainty=model_args.model_uncertainty,
        mask_uncertain=True,
        device=args.device)

    # Run training
    while not optimizer.is_finished_training():
        optimizer.start_epoch()

        # TODO: JBY, HACK WARNING  # What is the hack?
        # Stays None when no evaluation fires during the epoch.
        metrics = None
        for inputs, targets in train_loader:
            optimizer.start_iter()
            # Evaluate when global_step is a non-zero multiple of
            # iters_per_eval, or when fewer than batch_size examples are
            # left relative to optimizer.iter.
            # NOTE(review): presumably optimizer.iter counts examples seen
            # in the current epoch, making the second clause "last batch of
            # the epoch" - confirm against Optimizer.
            if (optimizer.global_step
                    and optimizer.global_step % optimizer.iters_per_eval == 0
                ) or (len(train_loader.dataset) - optimizer.iter
                      < optimizer.batch_size):

                # Only evaluate every iters_per_eval examples.
                predictions, groundtruth = predictor.predict(valid_loader)
                # print("predictions: {}".format(predictions))
                metrics, curves = evaluator.evaluate_tasks(
                    groundtruth, predictions)
                # Log metrics to stdout.
                logger.log_metrics(metrics)

                # Add logger for all the metrics for valid_loader
                logger.log_scalars(metrics, optimizer.global_step)

                # Get the metric used to save model checkpoints.
                average_metric = evaluator.evaluate_average_metric(
                    metrics, eval_tasks, optim_args.metric_name)

                if optimizer.global_step % logger_args.iters_per_save == 0:
                    # Only save every iters_per_save examples directly
                    # after evaluation.
                    print("Save global step: {}".format(optimizer.global_step))
                    saver.save(iteration=optimizer.global_step,
                               epoch=optimizer.epoch,
                               model=model,
                               optimizer=optimizer,
                               device=args.device,
                               metric_val=average_metric)

                # Step learning rate scheduler.
                optimizer.step_scheduler(average_metric)

            with torch.set_grad_enabled(True):
                # The model returns a second value alongside the logits; it
                # is not used by the training step.
                logits, _embedding = model(inputs.to(args.device))
                loss = loss_fn(logits, targets.to(args.device))
                optimizer.log_iter(inputs, logits, targets, loss)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            optimizer.end_iter()

        optimizer.end_epoch(metrics)

    logger.log('=== Training Complete ===')
Пример #16
0
    )

else:
    args.lr = 0.00035
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=args.lr,
                                 weight_decay=5e-4)
    lr_scheduler = WarmupMultiStepLR(
        optimizer,
        milestones=[200, 400],
        gamma=0.1,
        warmup_epochs=100,
    )

optimizer = Optimizer(optimizer=optimizer,
                      lr_scheduler=lr_scheduler,
                      max_epochs=800)

args.results_dir = os.path.join(
    args.results_dir,
    dataset,
    "{}_pooling_{}_loss_{}".format(args.optim, args.pooling_type,
                                   args.loss_type),
)

if args.non_local:
    args.results_dir = args.results_dir + "_nonlocal"

# run
solver = Engine(
    results_dir=args.results_dir,
Пример #17
0
    def train(self, train_db, val_db, test_db):
        """Train for ``cfg.n_epochs`` epochs, keeping the best checkpoint.

        Each epoch runs ``train_epoch`` and ``validate_epoch``, logs the
        per-column mean of the returned loss arrays through ``logz`` and
        calls ``save_best_checkpoint`` whenever the mean of validation
        column 0 drops below the best value seen so far.

        Args:
            train_db: Training data, forwarded to ``train_epoch``.
            val_db: Validation data, forwarded to ``validate_epoch``.
            test_db: Not used by this method.
        """
        cfg = self.cfg

        # ------------------------------------------------------------------
        # Optimizer
        # ------------------------------------------------------------------
        # Reach through the parallel wrapper (when one is in use) to get at
        # the sub-modules by name.
        core = self.net.module if cfg.cuda and cfg.parallel else self.net

        # Only image-encoder parameters that require gradients are trained.
        trainable_image_params = (
            p for p in core.image_encoder.parameters() if p.requires_grad)
        # Every group uses cfg.lr except the text embedding, which gets its
        # own fine-tuning rate.
        param_groups = [
            {'params': trainable_image_params},
            {'params': core.text_encoder.embedding.parameters(),
             'lr': cfg.finetune_lr},
            {'params': core.text_encoder.rnn.parameters()},
            {'params': core.what_decoder.parameters()},
            {'params': core.where_decoder.parameters()},
            {'params': core.shape_encoder.parameters()},
        ]
        adam = optim.Adam(param_groups, lr=cfg.lr)
        # No learning-rate scheduler is attached to the wrapper here.
        wrapped_optimizer = Optimizer(adam,
                                      max_grad_norm=cfg.grad_norm_clipping)

        # ------------------------------------------------------------------
        # Log
        # ------------------------------------------------------------------
        logz.configure_output_dir(cfg.model_dir)
        logz.save_config(cfg)

        # ------------------------------------------------------------------
        # Main loop
        # ------------------------------------------------------------------
        t0 = time()
        best_val_loss = 100000000
        for epoch in range(cfg.n_epochs):
            # Training pass.
            torch.cuda.empty_cache()
            train_loss = self.train_epoch(train_db, wrapped_optimizer, epoch)

            # Validation pass.
            torch.cuda.empty_cache()
            val_loss = self.validate_epoch(val_db, epoch)

            # Column 0 of the validation losses drives checkpoint selection.
            current_val_loss = np.mean(val_loss[:, 0])

            # One row per epoch: wall time, epoch index, then the mean of
            # columns 0-2 for the train and validation loss arrays.
            logz.log_tabular("Time", time() - t0)
            logz.log_tabular("Iteration", epoch)
            loss_columns = (
                ("AverageLoss", train_loss, 0),
                ("AverageEmbedLoss", train_loss, 1),
                ("AverageAttnLoss", train_loss, 2),
                ("ValAverageLoss", val_loss, 0),
                ("ValAverageEmbedLoss", val_loss, 1),
                ("ValAverageAttnLoss", val_loss, 2),
            )
            for key, losses, column in loss_columns:
                logz.log_tabular(key, np.mean(losses[:, column]))
            logz.dump_tabular()

            # Checkpoint only on strict improvement.
            if current_val_loss < best_val_loss:
                best_val_loss = current_val_loss
                self.save_best_checkpoint()
                torch.cuda.empty_cache()
Пример #18
0
    def train(self, train_db, val_db, test_db):
        """Train from ``self.start_epoch`` up to ``cfg.n_epochs``.

        Sets up two optimizers and stores them on the instance: an Adam
        optimizer (wrapped in ``Optimizer`` with a ``StepLR`` scheduler) for
        the image encoder and the what/where decoders, and an ``AdamW``
        optimizer with a linear warmup schedule for the text encoder. It
        also builds bucket samplers for the train and validation sets. Each
        epoch then trains, validates, optionally samples, logs a row of
        metrics through ``logz`` and saves a checkpoint.

        Args:
            train_db: Training data; passed to ``train_epoch`` and to the
                training ``BucketSampler``.
            val_db: Validation data; passed to ``validate_epoch`` and to the
                validation ``BucketSampler``.
            test_db: Data passed to ``self.sample`` when ``cfg.if_sample``
                is set.
        """
        cfg = self.cfg

        # ------------------------------------------------------------------
        # Optimizers and schedulers
        # ------------------------------------------------------------------
        # Reach through the parallel wrapper (when one is in use) to get at
        # the sub-modules by name.
        core = self.net.module if cfg.cuda and cfg.parallel else self.net

        trainable_image_params = (
            p for p in core.image_encoder.parameters() if p.requires_grad)
        # 'initial_lr' is supplied in every group because the schedulers
        # below are created with an explicit last_epoch.
        base_groups = [
            {'params': trainable_image_params, 'initial_lr': cfg.lr},
            {'params': core.what_decoder.parameters(), 'initial_lr': cfg.lr},
            {'params': core.where_decoder.parameters(), 'initial_lr': cfg.lr},
        ]
        self.optimizer = Optimizer(optim.Adam(base_groups, lr=cfg.lr),
                                   max_grad_norm=cfg.grad_norm_clipping)
        step_lr = optim.lr_scheduler.StepLR(self.optimizer.optimizer,
                                            step_size=3,
                                            gamma=0.8,
                                            last_epoch=self.start_epoch - 1)
        self.optimizer.set_scheduler(step_lr)

        # The text encoder has its own optimizer and warmup schedule.
        # NOTE(review): these are only created here; presumably they are
        # stepped inside train_epoch - confirm.
        steps_per_epoch = len(train_db) / cfg.accumulation_steps
        num_train_steps = int(steps_per_epoch * cfg.n_epochs)
        num_warmup_steps = int(num_train_steps * cfg.warmup)
        text_groups = [{
            'params': core.text_encoder.parameters(),
            'initial_lr': cfg.finetune_lr,
        }]
        self.bert_optimizer = AdamW(text_groups, lr=cfg.finetune_lr)
        self.bert_scheduler = get_linear_schedule_with_warmup(
            self.bert_optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_train_steps,
            last_epoch=self.start_epoch - 1)

        # ------------------------------------------------------------------
        # Bucket samplers
        # ------------------------------------------------------------------
        boundaries = [4, 8, 12, 16, 22]
        print('preparing training bucket sampler')
        self.train_bucket_sampler = BucketSampler(train_db,
                                                  boundaries,
                                                  batch_size=cfg.batch_size)
        print('preparing validation bucket sampler')
        self.val_bucket_sampler = BucketSampler(val_db,
                                                boundaries,
                                                batch_size=4)

        # ------------------------------------------------------------------
        # Log
        # ------------------------------------------------------------------
        logz.configure_output_dir(cfg.model_dir)
        logz.save_config(cfg)

        # ------------------------------------------------------------------
        # Main loop
        # ------------------------------------------------------------------
        t0 = time()
        log = logz.log_tabular

        for epoch in range(self.start_epoch, cfg.n_epochs):
            # Training pass. Only three of the five returned values are
            # logged below.
            print('Training...')
            torch.cuda.empty_cache()
            (train_pred_loss, _train_attn_loss, _train_eos_loss, train_accu,
             train_mse) = self.train_epoch(train_db, self.optimizer, epoch)

            # Validation pass.
            print('Validation...')
            val_loss, val_accu, val_mse, val_infos = \
                self.validate_epoch(val_db)

            # Optional sampling pass.
            if cfg.if_sample:
                print('Sample...')
                torch.cuda.empty_cache()
                self.sample(epoch, test_db, cfg.n_samples)
                torch.cuda.empty_cache()

            # Feed the mean validation loss to the optimizer wrapper, then
            # write this epoch's row of metrics.
            print('Loging...')
            mean_val_loss = np.mean(val_loss)
            mean_val_accu = np.mean(val_accu)
            self.optimizer.update(mean_val_loss, epoch)

            log("Time", time() - t0)
            log("Iteration", epoch)

            log("TrainAverageError", np.mean(train_pred_loss))
            log("TrainAverageAccu", np.mean(train_accu))
            log("TrainAverageMse", np.mean(train_mse))
            log("ValAverageError", mean_val_loss)
            log("ValAverageAccu", mean_val_accu)
            accu_columns = (
                ("ValAverageObjAccu", 0),
                ("ValAverageCoordAccu", 1),
                ("ValAverageScaleAccu", 2),
                ("ValAverageRatioAccu", 3),
            )
            for key, column in accu_columns:
                log(key, np.mean(val_accu[:, column]))
            log("ValAverageMse", np.mean(val_mse))
            mse_columns = (
                ("ValAverageXMse", 0),
                ("ValAverageYMse", 1),
                ("ValAverageWMse", 2),
                ("ValAverageHMse", 3),
            )
            for key, column in mse_columns:
                log(key, np.mean(val_mse[:, column]))
            log("ValUnigramF3", np.mean(val_infos.unigram_F3()))
            log("ValBigramF3", np.mean(val_infos.bigram_F3()))
            log("ValUnigramP", np.mean(val_infos.unigram_P()))
            log("ValUnigramR", np.mean(val_infos.unigram_R()))
            log("ValBigramP", val_infos.mean_bigram_P())
            log("ValBigramR", val_infos.mean_bigram_R())
            log("ValUnigramScale", np.mean(val_infos.scale()))
            log("ValUnigramRatio", np.mean(val_infos.ratio()))
            log("ValUnigramSim", np.mean(val_infos.unigram_coord()))
            log("ValBigramSim", val_infos.mean_bigram_coord())

            logz.dump_tabular()

            # A checkpoint is written every epoch, tagged with the mean
            # validation loss and accuracy.
            print('Saving checkpoint...')
            self.save_checkpoint(epoch, [mean_val_loss, mean_val_accu])
            torch.cuda.empty_cache()
Пример #19
0
def main():
    """Train an RNN language model on Penn Treebank and report perplexity.

    Parses command-line options, loads the PTB word datasets through
    Chainer, builds the ``RNN`` model and loads weights from
    ``--model-filename``, then runs ``--total-epochs`` epochs. Each epoch
    trains on randomly sampled windows, saves the model, computes the
    perplexity on the dev set and prints a summary line. From epoch index
    ``--lr-decay-epoch`` onward the learning rate is multiplied by 0.98 per
    epoch, with a final value of 1e-5.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batchsize", "-b", type=int, default=64)
    parser.add_argument("--seq-length", "-l", type=int, default=35)
    parser.add_argument("--total-epochs", "-e", type=int, default=300)
    parser.add_argument("--gpu-device", "-g", type=int, default=0)
    parser.add_argument("--grad-clip", "-gc", type=float, default=5)
    parser.add_argument("--learning-rate", "-lr", type=float, default=1)
    parser.add_argument("--weight-decay", "-wd", type=float, default=0.000001)
    parser.add_argument("--dropout-embedding-softmax",
                        "-dos",
                        type=float,
                        default=0.5)
    parser.add_argument("--dropout-rnn", "-dor", type=float, default=0.2)
    parser.add_argument("--variational-dropout",
                        "-vdo",
                        dest="variational_dropout",
                        action="store_true",
                        default=False)
    # --use-tanh and --use-identity share one destination; tanh is the
    # default and --use-identity switches it off.
    parser.add_argument("--use-tanh",
                        "-tanh",
                        dest="use_tanh",
                        action="store_true",
                        default=True)
    parser.add_argument("--use-identity",
                        "-identity",
                        dest="use_tanh",
                        action="store_false")
    parser.add_argument("--momentum", "-mo", type=float, default=0.9)
    parser.add_argument("--optimizer", "-opt", type=str, default="msgd")
    parser.add_argument("--ndim-feature", "-nf", type=int, default=640)
    parser.add_argument("--num-layers", "-nl", type=int, default=2)
    parser.add_argument("--lr-decay-epoch", "-lrd", type=int, default=20)
    parser.add_argument("--model-filename",
                        "-m",
                        type=str,
                        default="model.hdf5")
    args = parser.parse_args()

    print("#layers={}".format(args.num_layers))
    print("d={}".format(args.ndim_feature))
    print("dropout={}".format(
        "Variational" if args.variational_dropout else "Standard"))
    print("g={}".format("tanh" if args.use_tanh else "identity"))

    assert args.num_layers > 0
    assert args.ndim_feature > 0

    dataset_train, dataset_dev, dataset_test = chainer.datasets.get_ptb_words()
    dataset_dev = np.asarray(dataset_dev, dtype=np.int32)

    # Word ids start at 0, so the vocabulary size is the largest id plus one.
    vocab_size = max(dataset_train) + 1
    rnn = RNN(vocab_size,
              ndim_feature=args.ndim_feature,
              num_layers=args.num_layers,
              use_tanh=args.use_tanh,
              dropout_embedding_softmax=args.dropout_embedding_softmax,
              dropout_rnn=args.dropout_rnn,
              variational_dropout=args.variational_dropout)
    rnn.load(args.model_filename)

    # Number of (batchsize x seq_length) windows that fit in the train set.
    total_iterations_train = len(dataset_train) // (args.seq_length *
                                                    args.batchsize)

    optimizer = Optimizer(args.optimizer, args.learning_rate, args.momentum)
    optimizer.setup(rnn.model)
    if args.grad_clip > 0:
        optimizer.add_hook(chainer.optimizer.GradientClipping(args.grad_clip))
    if args.weight_decay > 0:
        optimizer.add_hook(chainer.optimizer.WeightDecay(args.weight_decay))

    using_gpu = False
    if args.gpu_device >= 0:
        cuda.get_device(args.gpu_device).use()
        rnn.model.to_gpu()
        using_gpu = True
    xp = rnn.model.xp

    training_start_time = time.time()
    for epoch in range(args.total_epochs):

        sum_loss = 0
        epoch_start_time = time.time()

        # training
        for itr in range(total_iterations_train):
            # Sample a minibatch of random windows. The upper bound leaves
            # room for the target sequence, which is shifted by one token.
            batch_offsets = np.random.randint(0,
                                              len(dataset_train) -
                                              args.seq_length - 1,
                                              size=args.batchsize)
            x_batch = np.empty((args.batchsize, args.seq_length),
                               dtype=np.int32)
            t_batch = np.empty((args.batchsize, args.seq_length),
                               dtype=np.int32)
            for batch_index, offset in enumerate(batch_offsets):
                sequence = dataset_train[offset:offset + args.seq_length]
                teacher = dataset_train[offset + 1:offset + args.seq_length +
                                        1]
                x_batch[batch_index] = sequence
                t_batch[batch_index] = teacher

            if using_gpu:
                x_batch = cuda.to_gpu(x_batch)
                t_batch = cuda.to_gpu(t_batch)

            t_batch = flatten(t_batch)

            # update model parameters
            with chainer.using_config("train", True):
                rnn.reset_state()
                y_batch = rnn(x_batch, flatten=True)
                loss = functions.softmax_cross_entropy(y_batch, t_batch)

                rnn.model.cleargrads()
                loss.backward()
                optimizer.update()

                sum_loss += float(loss.data)
                # NaN is the only value that is not equal to itself.
                assert sum_loss == sum_loss, "Encountered NaN!"

            printr("Training ... {:3.0f}% ({}/{})".format(
                (itr + 1) / total_iterations_train * 100, itr + 1,
                total_iterations_train))

        rnn.save(args.model_filename)

        # Evaluation: walk the dev set once, in order, in chunks of at most
        # seq_length tokens, without resetting the state between chunks.
        x_sequence = dataset_dev[:-1]
        t_sequence = dataset_dev[1:]
        rnn.reset_state()
        total_iterations_dev = math.ceil(len(x_sequence) / args.seq_length)
        offset = 0
        negative_log_likelihood = 0
        for itr in range(total_iterations_dev):
            # The last chunk may be shorter than args.seq_length.
            seq_length = min(offset + args.seq_length,
                             len(x_sequence)) - offset
            x_batch = x_sequence[None, offset:offset + seq_length]
            t_batch = flatten(t_sequence[None, offset:offset + seq_length])

            if using_gpu:
                x_batch = cuda.to_gpu(x_batch)
                t_batch = cuda.to_gpu(t_batch)

            # Enter BOTH context managers. Joining them with `and` would
            # evaluate to the second manager only, so no_backprop_mode
            # would never be entered.
            with chainer.no_backprop_mode(), \
                    chainer.using_config("train", False):
                y_batch = rnn(x_batch, flatten=True)
                # The loss is weighted by the chunk length before summing.
                negative_log_likelihood += float(
                    functions.softmax_cross_entropy(y_batch,
                                                    t_batch).data) * seq_length

            printr("Computing perplexity ...{:3.0f}% ({}/{})".format(
                (itr + 1) / total_iterations_dev * 100, itr + 1,
                total_iterations_dev))
            offset += seq_length

        assert negative_log_likelihood == negative_log_likelihood, "Encountered NaN!"
        # Normalise by the number of targets actually scored: the chunk
        # lengths sum to len(x_sequence), one less than len(dataset_dev).
        perplexity = math.exp(negative_log_likelihood / len(x_sequence))

        clear_console()
        print(
            "Epoch {} done in {} sec - loss: {:.6f} - log_likelihood: {} - ppl: {} - lr: {:.3g} - total {} min"
            .format(epoch + 1, int(time.time() - epoch_start_time),
                    sum_loss / total_iterations_train,
                    int(-negative_log_likelihood), int(perplexity),
                    optimizer.get_learning_rate(),
                    int((time.time() - training_start_time) // 60)))

        if epoch >= args.lr_decay_epoch:
            optimizer.decrease_learning_rate(0.98, final_value=1e-5)
Пример #20
0
# y = y / y.max(axis=0)

# Add baseline feature in column 0: prepend a column of ones to X so that
# the first fitted parameter acts as a constant (intercept) term in the
# prediction computed below.
X = np.hstack((np.ones(X.shape[0]).reshape(-1, 1), X))
# %% [markdown]
# ## Experiment 1: Numerical Approximation

# %%
# Cost object built from the augmented feature matrix and the targets.
cost = LinearCostFunction(X, y)

# Settings passed positionally to Optimizer below.
# NOTE(review): presumably tol is the convergence tolerance and delta the
# finite-difference step used for the numerical gradient - confirm against
# the Optimizer class.
step_size = 0.1
max_iter = 5000
tol = 1e-8
delta = 1e-5

optimizer = Optimizer(step_size, max_iter, tol, delta)
# Start the search from the all-zeros parameter vector, one entry per
# column of X (including the added column of ones).
initial_params = np.zeros(X.shape[1])
optimized_params, iters = optimizer.optimize(cost, initial_params)
print(
    f'Found min at {optimized_params} starting at {initial_params} in {iters} iterations of optimization algorithm.'
)
# Row-wise dot product of the features with the fitted parameters.
y_pred = np.sum(X * optimized_params, axis=1)
# y_true = np.sum(X, axis=1)
# Predicted vs. actual targets, with the line y = x from 0 to 25 drawn as a
# reference for a perfect fit.
plt.scatter(y_pred, y)
plt.plot([0, 25], [0, 25])
plt.show()

# %% [markdown]
# ## Experiment 2: Normal Solutions

# %%
Пример #21
0
    if opt.resume and not opt.load_checkpoint:
        last_checkpoint = get_last_checkpoint(opt.best_model_dir)
    if last_checkpoint:
        opt.load_checkpoint = os.path.join(opt.model_dir, last_checkpoint)
        opt.skip_steps = int(last_checkpoint.strip('.pt').split('/')[-1])

    if opt.load_checkpoint:
        model.load_state_dict(torch.load(opt.load_checkpoint))
        opt.skip_steps = int(opt.load_checkpoint.strip('.pt').split('/')[-1])
        logger.info(f"\nLoad from {opt.load_checkpoint}\n")
    else:
        for param in model.parameters():
            param.data.uniform_(-opt.init_weight, opt.init_weight)

    optimizer = optim.Adam(model.parameters())
    optimizer = Optimizer(optimizer, max_grad_norm=opt.clip_grad)
    loss = nn.CrossEntropyLoss()
    model = model.to(device)
    loss = loss.to(device)

    if opt.phase == 'train':
        trans_data = TranslateData()
        train_set = Short_text_Dataset(opt.train_path,
                                       trans_data.translate_data,
                                       src_vocab,
                                       max_src_length=opt.max_src_length)
        train = DataLoader(train_set,
                           batch_size=opt.batch_size,
                           shuffle=False,
                           drop_last=True,
                           collate_fn=trans_data.collate_fn)
Пример #22
0
def train(args):
    """Run the full training loop with periodic validation and checkpoints.

    Builds the logger, model, data loaders, predictor, evaluator, saver and
    optimizer from the option groups on ``args`` and iterates over the
    training loader until ``optimizer.is_finished_training()`` is true.
    Whenever ``optimizer.global_step`` is a multiple of
    ``optimizer.iters_per_eval`` the model is scored on the "valid" and
    "dense_valid" loaders; the merged metrics are logged, optionally used to
    save a checkpoint (when the step is also a multiple of
    ``logger_args.iters_per_save``) and passed to
    ``optimizer.step_scheduler``. One more checkpoint is saved after the loop.

    Args:
        args: Namespace exposing ``model_args``, ``logger_args``,
            ``optim_args``, ``data_args``, ``gpu_ids`` and ``device``.
    """

    # Split the top-level namespace into its option groups.
    model_opts = args.model_args
    log_opts = args.logger_args
    optim_opts = args.optim_args
    data_opts = args.data_args

    logger = Logger(log_opts)

    if model_opts.ckpt_path:
        # Resuming: the command-line model arguments are used to load the
        # model rather than the ones saved to args.json.
        model_opts.pretrained = False
        ckpt_path = model_opts.ckpt_path
        # NOTE(review): this assertion always fails, so the loading code
        # below only runs when assertions are disabled (python -O).
        assert False
        model, ckpt_info = ModelSaver.load_model(ckpt_path=ckpt_path,
                                                 gpu_ids=args.gpu_ids,
                                                 model_args=model_opts,
                                                 is_training=True)
        optim_opts.start_epoch = ckpt_info['epoch'] + 1
    else:
        # No checkpoint given: look the model constructor up by name and
        # wrap the new instance for the requested GPUs.
        build_model = models.__dict__[model_opts.model]
        model = nn.DataParallel(build_model(model_opts), args.gpu_ids)

    # Move to the target device and switch to training mode.
    model = model.to(args.device)
    model.train()

    def make_loader(phase, is_training):
        """Create the loader for ``phase`` with the shared data options."""
        return get_loader(phase=phase,
                          data_args=data_opts,
                          is_training=is_training,
                          logger=logger)

    train_loader = make_loader("train", True)
    valid_loader = make_loader("valid", False)
    dense_valid_loader = make_loader("dense_valid", False)

    # Helpers for prediction, evaluation and checkpointing.
    predictor = Predictor(model, args.device)
    evaluator = Evaluator(logger=logger, tune_threshold=True)
    saver = ModelSaver(save_dir=log_opts.save_dir,
                       iters_per_save=log_opts.iters_per_save,
                       max_ckpts=log_opts.max_ckpts,
                       metric_name=optim_opts.metric_name,
                       maximize_metric=optim_opts.maximize_metric,
                       keep_topk=True,
                       logger=logger)

    # The project Optimizer also tracks epochs, global steps and logging.
    optimizer = Optimizer(parameters=model.parameters(),
                          optim_args=optim_opts,
                          batch_size=data_opts.batch_size,
                          iters_per_print=log_opts.iters_per_print,
                          iters_per_visual=log_opts.iters_per_visual,
                          iters_per_eval=log_opts.iters_per_eval,
                          dataset_len=len(train_loader.dataset),
                          logger=logger)
    if model_opts.ckpt_path:
        # Restore the optimizer state that belongs to the checkpoint.
        optimizer.load_optimizer(ckpt_path=model_opts.ckpt_path,
                                 gpu_ids=args.gpu_ids)

    loss_fn = evaluator.get_loss_fn(loss_fn_name=optim_opts.loss_fn)

    while not optimizer.is_finished_training():
        optimizer.start_epoch()

        for inputs, targets in train_loader:
            optimizer.start_iter()

            if optimizer.global_step % optimizer.iters_per_eval == 0:
                # Periodic evaluation on the regular validation set ...
                preds, truth = predictor.predict(valid_loader)
                metrics = evaluator.evaluate(truth, preds)

                # ... and on the dense validation set; on a key clash the
                # dense value wins.
                dense_preds, dense_truth = predictor.predict(
                    dense_valid_loader)
                dense_metrics = evaluator.dense_evaluate(
                    dense_truth, dense_preds)
                metrics = {**metrics, **dense_metrics}

                # Report through the logger (log_metrics, then log_scalars).
                logger.log_metrics(metrics, phase='valid')
                logger.log_scalars(metrics,
                                   optimizer.global_step,
                                   phase='valid')

                if optimizer.global_step % log_opts.iters_per_save == 0:
                    # Checkpoints are only written right after an
                    # evaluation so a metric value is always available.
                    saver.save(iteration=optimizer.global_step,
                               epoch=optimizer.epoch,
                               model=model,
                               optimizer=optimizer,
                               device=args.device,
                               metric_val=metrics[optim_opts.metric_name])

                # Let the LR scheduler see the tracked metric.
                optimizer.step_scheduler(metrics[optim_opts.metric_name])

            with torch.set_grad_enabled(True):
                # Forward pass and loss for this minibatch.
                logits = model(inputs.to(args.device))
                loss = loss_fn(logits, targets.to(args.device))

                optimizer.log_iter(inputs, logits, targets, loss)

                # Backward pass and parameter update.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            optimizer.end_iter()

        optimizer.end_epoch(metrics)

    # Final checkpoint with the most recent metrics.
    saver.save(iteration=optimizer.global_step,
               epoch=optimizer.epoch,
               model=model,
               optimizer=optimizer,
               device=args.device,
               metric_val=metrics[optim_opts.metric_name])
Пример #23
0
    if config.training.num_gpu > 0:
        model = model.cuda()
        if config.training.num_gpu > 1:
            device_ids = list(range(config.training.num_gpu))
            model = torch.nn.DataParallel(model, device_ids=device_ids)
        logger.info('Loaded the model to %d GPUs' % config.training.num_gpu)

    n_params, enc, dec, fir_enc = count_parameters(model)
    logger.info('# the number of parameters in the whole model: %d' % n_params)
    logger.info('# the number of parameters in the Encoder: %d' % enc)
    logger.info('# the number of parameters in the Decoder: %d' % dec)
    logger.info('# the number of parameters in the fir_enc: %d' % fir_enc)
    logger.info('# the number of parameters in the JointNet: %d' %
                (n_params - dec - enc - fir_enc))
    optimizer = Optimizer(model.parameters(), config.optim)
    # optimizer = torch.optim.adam(model.parameters(),lr=config.optim.lr,betas=(0.9, 0.98),eps=1e-08,weight_decay=config.optim.weight_decay)
    logger.info('Created a %s optimizer.' % config.optim.type)
    if opt.mode == 'continue':
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        logger.info('From epcho:%d training! ' % start_epoch)
        logger.info('Load Optimizer State!')
    else:
        start_epoch = 0

    # create a visualizer
    if config.training.visualization:
        visualizer = SummaryWriter(os.path.join(exp_name, 'log'))
        logger.info('Created a visualizer.')
    else:
Пример #24
0
class SupervisedTrainer(object):
    def __init__(self, db):
        """Build the network and reset all trainer state.

        Args:
            db: Dataset object. Its ``cfg`` attribute becomes the trainer
                configuration and the object itself is handed to
                ``DrawModel``.
        """
        self.cfg = db.cfg
        self.db = db

        model = DrawModel(db, if_bert=True)
        if self.cfg.cuda:
            # DataParallel is only applied when it was requested AND more
            # than one device is visible.
            use_data_parallel = (self.cfg.parallel
                                 and torch.cuda.device_count() > 1)
            if use_data_parallel:
                print("Let's use", torch.cuda.device_count(), "GPUs!")
                model = nn.DataParallel(model)
            model = model.cuda()
        self.net = model

        # May be overwritten by load_pretrained_net below.
        self.start_epoch = 0
        pretrained = self.cfg.pretrained
        if pretrained is not None:
            self.load_pretrained_net(pretrained)

        self.writer = SummaryWriter(comment='Layout')
        self.xymap = CocoLocationMap(self.cfg)
        self.whmap = CocoTransformationMap(self.cfg)

        # Optimizers, schedulers and samplers are created later in train().
        self.optimizer = self.bert_optimizer = self.bert_scheduler = None
        self.mse_loss = nn.MSELoss(reduction='sum')
        self.train_bucket_sampler = self.val_bucket_sampler = None
        self.global_step = 0

    def get_parameter_number(self):
        net = self.net
        total_num = sum(p.numel() for p in net.parameters())
        trainable_num = sum(p.numel() for p in net.parameters()
                            if p.requires_grad)
        print('Total', total_num, 'Trainable', trainable_num)

    # def load_optimizer(self, pretrained_name):
    #     cache_dir = osp.join(self.cfg.data_dir, 'caches')
    #     pretrained_path = osp.join(cache_dir, 'bert_layout_ckpts', pretrained_name+'.pkl')
    #     bert_path = osp.join(cache_dir, 'bert_layout_ckpts', 'bert-'+pretrained_name+'.pkl')

    #     assert osp.exists(pretrained_path)
    #     if self.cfg.cuda:
    #         checkpoint = torch.load(pretrained_path)
    #         bert_checkpoint = torch.load(bert_path)
    #     else:
    #         checkpoint = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
    #         bert_checkpoint = torch.load(bert_path, map_location=lambda storage, loc: storage)
    #     self.optimizer.optimizer.load_state_dict(checkpoint['optimizer'])
    #     self.bert_optimizer.load_state_dict(bert_checkpoint['optimizer'])

    def load_pretrained_net(self, pretrained_name):
        if self.cfg.cuda and self.cfg.parallel:
            net = self.net.module
        else:
            net = self.net

        # cache_dir = osp.join(self.cfg.data_dir, 'caches')
        # pretrained_path = osp.join(cache_dir, 'bert_layout_ckpts', pretrained_name+'.pkl')
        pretrained_path = pretrained_name
        # self.begin_epoch = int(pretrained_name.split('-')[1]) + 1
        print('loading ckpt from ', pretrained_path)

        assert osp.exists(pretrained_path)
        if self.cfg.cuda:
            checkpoint = torch.load(pretrained_path)
        else:
            checkpoint = torch.load(pretrained_path,
                                    map_location=lambda storage, loc: storage)
        net.load_state_dict(checkpoint['net'])
        # self.optimizer.optimizer.load_state_dict(checkpoint['optimizer'])
        self.start_epoch = checkpoint['epoch'] + 1
        print('Start training from {} epoch'.format(self.start_epoch))

    def batch_data(self, entry):
        """Turn one collated batch dict into the tensors used for training.

        Reads the keys ``obj_cnt``, ``bert_inds``, ``bert_lens``, ``fg_inds``,
        ``background``, ``out_inds``, ``out_msks``, ``scene_idx``, ``boxes``
        and ``box_msks`` from ``entry`` and casts them to long/float.

        When ``cfg.cuda`` is set, the tensors are moved to the GPU and the
        second axis of everything except the sentence tensors is cut to the
        largest object count in the batch (one extra step is kept for the
        one-hot foreground tensor). Without CUDA nothing is trimmed.

        Args:
            entry: Mapping of batched tensors as produced by the data loader.

        Returns:
            Tuple ``(input_inds, input_lens, bg_imgs, fg_onehots, gt_inds,
            gt_msks, gt_scene_inds, gt_boxes, box_msks)``. ``gt_scene_inds``
            is a NumPy array, all other items are tensors.
        """
        # Largest object count in this batch.
        longest = max(entry['obj_cnt']).item()

        # ---- model inputs ----
        sent_inds = entry['bert_inds'].long()
        sent_lens = entry['bert_lens'].unsqueeze(-1).long()
        fg_inds = entry['fg_inds'].long()
        backgrounds = entry['background'].float()
        onehots = indices2onehots(fg_inds, self.cfg.output_cls_size)

        # ---- supervision targets and their masks ----
        target_inds = entry['out_inds'].long()
        target_msks = entry['out_msks'].float()
        scene_inds = entry['scene_idx'].long().numpy()
        target_boxes = entry['boxes'].float()
        boxes_msks = entry['box_msks'].float()

        if self.cfg.cuda:
            sent_inds = sent_inds.cuda()
            sent_lens = sent_lens.cuda()
            # The one-hot inputs keep one more step than the targets.
            onehots = onehots[:, :longest + 1, :].cuda()
            backgrounds = backgrounds[:, :longest, :, :, :].cuda()
            target_inds, target_msks, target_boxes, boxes_msks = (
                tensor[:, :longest, :].cuda()
                for tensor in (target_inds, target_msks, target_boxes,
                               boxes_msks))

        return (sent_inds, sent_lens, backgrounds, onehots, target_inds,
                target_msks, scene_inds, target_boxes, boxes_msks)

    def evaluate(self, inf_outs, ref_inds, ref_msks, ref_boxes, box_msks):
        ####################################################################
        # Prediction loss
        ####################################################################
        if self.cfg.cuda and self.cfg.parallel:
            net = self.net.module
        else:
            net = self.net
        if not self.cfg.mse:
            _, _, _, enc_msks, what_wei, where_wei = inf_outs
        else:
            obj_logits, pred_boxes, enc_msks, what_wei, where_wei = inf_outs

        ####################################################################
        # doubly stochastic attn loss
        ####################################################################
        attn_loss = 0
        encoder_msks = enc_msks

        if self.cfg.what_attn:
            obj_msks = ref_msks[:, :, 0].unsqueeze(-1)
            what_att_logits = what_wei
            raw_obj_att_loss = torch.mul(what_att_logits, obj_msks)
            raw_obj_att_loss = torch.sum(raw_obj_att_loss, dim=1)
            obj_att_loss = raw_obj_att_loss - encoder_msks
            obj_att_loss = torch.sum(obj_att_loss**2, dim=-1)
            obj_att_loss = torch.mean(obj_att_loss)
            attn_loss = attn_loss + obj_att_loss

        attn_loss = self.cfg.attn_loss_weight * attn_loss

        eos_loss = 0
        if self.cfg.what_attn and self.cfg.eos_loss_weight > 0:
            # print('-------------------')
            # print('obj_msks: ', obj_msks.size())
            inds_1 = torch.sum(obj_msks, 1, keepdim=True) - 1
            # print('inds_1: ', inds_1.size())
            bsize, tlen, slen = what_att_logits.size()
            # print('what_att_logits: ', what_att_logits.size())
            inds_1 = inds_1.expand(bsize, 1, slen).long()
            local_eos_probs = torch.gather(what_att_logits, 1,
                                           inds_1).squeeze(1)
            # print('local_eos_probs: ', local_eos_probs.size())
            # print('encoder_msks: ', encoder_msks.size())
            inds_2 = torch.sum(encoder_msks, 1, keepdim=True) - 1
            # print('inds_2: ', inds_2.size())
            eos_probs = torch.gather(local_eos_probs, 1, inds_2.long())
            norm_probs = torch.gather(raw_obj_att_loss, 1, inds_2.long())
            # print('norm_probs:', norm_probs.size())
            # print('eos_probs: ', eos_probs.size())
            eos_loss = -torch.log(eos_probs.clamp(min=self.cfg.eps))
            eos_loss = torch.mean(eos_loss)
            diff = torch.sum(norm_probs) - 1.0
            norm_loss = diff * diff
            # print('obj_att_loss: ', att_loss)
            # print('eos_loss: ', eos_loss)
            # print('norm_loss: ', norm_loss)
        eos_loss = self.cfg.eos_loss_weight * eos_loss

        # torch.cuda.synchronize()
        # s = time()

        if not self.cfg.mse:
            # _, _, _, enc_msks, what_wei, where_wei = inf_outs

            logits = net.collect_logits(inf_outs, ref_inds)
            bsize, slen, _ = logits.size()
            loss_wei = [
                self.cfg.obj_loss_weight, \
                self.cfg.coord_loss_weight, \
                self.cfg.scale_loss_weight, \
                self.cfg.ratio_loss_weight
            ]
            loss_wei = torch.from_numpy(np.array(loss_wei)).float()
            if self.cfg.cuda:
                loss_wei = loss_wei.cuda()
            loss_wei = loss_wei.view(1, 1, 4)
            loss_wei = loss_wei.expand(bsize, slen, 4)

            pred_loss = -torch.log(
                logits.clamp(min=self.cfg.eps)) * loss_wei * ref_msks
            pred_loss = torch.sum(pred_loss) / (torch.sum(ref_msks) +
                                                self.cfg.eps)

            ####################################################################
            # Accuracies
            ####################################################################
            pred_accu, pred_mse = net.collect_accuracies(inf_outs, ref_inds)
            pred_accu = pred_accu * ref_msks

            mse_msks = ref_msks[:, :, -1].unsqueeze(-1).expand(-1, -1, 4)
            pred_mse = pred_mse * mse_msks

            comp_accu = torch.sum(torch.sum(pred_accu, 0), 0)
            comp_msks = torch.sum(torch.sum(ref_msks, 0), 0)
            pred_accu = comp_accu / (comp_msks + self.cfg.eps)

            comp_mse = torch.sum(torch.sum(pred_mse, 0), 0)
            comp_msks = torch.sum(torch.sum(mse_msks, 0), 0)
            pred_mse = comp_mse / (comp_msks + self.cfg.eps)

        else:
            #inf_outs = (obj_logits, pred_boxes, enc_msks, what_wei, where_wei)
            # obj_logits, pred_boxes, enc_msks, what_wei, where_wei = inf_outs
            b_size, tlen, _ = pred_boxes.size()
            ref_boxes = ref_boxes * box_msks

            # logits = net.collect_logits(inf_outs, ref_inds)
            obj_inds = ref_inds[:, :, 0].unsqueeze(-1)
            sample_obj_logits = torch.gather(obj_logits, -1, obj_inds)

            obj_loss = -torch.log(sample_obj_logits.clamp(
                min=self.cfg.eps)) * ref_msks[:, :, 0].unsqueeze(-1)
            obj_loss = torch.sum(obj_loss) / (torch.sum(ref_msks[:, :, 0]) +
                                              self.cfg.eps)

            # obj_logits = obj_logits.float() * box_msks[:,:,0].unsqueeze(-1).expand(-1,-1,83)
            pred_boxes = pred_boxes * box_msks[:, :, 1:]

            gt_obj = ref_boxes[:, :, 0].unsqueeze(-1).long()
            gt_boxes = ref_boxes[:, :, 1:]

            # loss_wei = [
            #     self.cfg.obj_loss_weight, \
            #     self.cfg.coord_loss_weight, \
            #     self.cfg.scale_loss_weight, \
            #     self.cfg.ratio_loss_weight
            # ]
            # loss_wei = torch.from_numpy(np.array(loss_wei)).float()
            # if self.cfg.cuda:
            #     loss_wei = loss_wei.cuda()
            # loss_wei = loss_wei.view(1,1,4)
            # loss_wei = loss_wei.expand(bsize, slen, 4)

            # torch.cuda.synchronize()
            # print(time()-s)
            # torch.cuda.synchronize()
            # s = time()

            all_mse = (pred_boxes - gt_boxes)**2

            comp_mse = torch.sum(torch.sum(all_mse, 0), 0)
            comp_msks = torch.sum(torch.sum(box_msks[:, :, 1:], 0), 0)
            pred_mse = comp_mse / (comp_msks + self.cfg.eps)
            coord_loss = torch.sum(pred_mse)

            pred_loss = obj_loss + coord_loss
            # pred_loss = obj_loss + x_loss + y_loss + w_loss + h_loss
            # torch.cuda.synchronize()
            # print(time()-s)
            # torch.cuda.synchronize()
            # s = time()

            ####################################################################
            # Accuracies
            ####################################################################
            coord = self.db.boxes2indices(pred_boxes.contiguous().view(
                -1, 4).detach().cpu().numpy())
            coord = torch.from_numpy(coord).view(b_size, tlen, -1).cuda()

            _, pred_obj_inds = torch.max(obj_logits, -1)
            obj_accu = torch.eq(pred_obj_inds,
                                ref_inds[:, :, 0]).float().unsqueeze(-1)
            coord_accu = torch.eq(coord, ref_inds[:, :, 1:]).float()

            # pred_accu, pred_mse = net.collect_accuracies(inf_outs, ref_inds)
            pred_accu = torch.cat([obj_accu, coord_accu], -1)
            pred_accu = pred_accu * ref_msks

            comp_accu = torch.sum(torch.sum(pred_accu, 0), 0)
            comp_msks = torch.sum(torch.sum(ref_msks, 0), 0)
            pred_accu = comp_accu / (comp_msks + self.cfg.eps)

            # pred_mse = torch.stack([x_loss, y_loss, w_loss, h_loss], -1)
            # torch.cuda.synchronize()
            # print(time()-s)
            # print('====================')

        return pred_loss, attn_loss, eos_loss, pred_accu, pred_mse

    def train(self, train_db, val_db, test_db):
        """Run the whole training schedule: train, validate, sample, log, save.

        Creates an Adam optimizer (image encoder plus both decoders) with a
        StepLR schedule, a separate AdamW optimizer with linear warmup for
        the text encoder, and the bucket samplers. It then iterates over the
        epochs from ``self.start_epoch`` up to ``cfg.n_epochs``.

        Args:
            train_db: Training dataset.
            val_db: Validation dataset.
            test_db: Dataset passed to ``self.sample`` when ``cfg.if_sample``
                is set.
        """
        ##################################################################
        ## Optimizer
        ##################################################################
        # Unwrap according to what self.net really is: __init__ only uses
        # DataParallel when more than one GPU is visible, so the config
        # flags alone are not a reliable indicator.
        if isinstance(self.net, nn.DataParallel):
            net = self.net.module
        else:
            net = self.net

        image_encoder_trainable_paras = \
            filter(lambda p: p.requires_grad, net.image_encoder.parameters())
        # 'initial_lr' has to be present in every group because the
        # schedulers below are created with last_epoch != -1 when resuming.
        raw_optimizer = optim.Adam([{
            'params': image_encoder_trainable_paras,
            'initial_lr': self.cfg.lr
        }, {
            'params': net.what_decoder.parameters(),
            'initial_lr': self.cfg.lr
        }, {
            'params': net.where_decoder.parameters(),
            'initial_lr': self.cfg.lr
        }],
                                   lr=self.cfg.lr)
        self.optimizer = Optimizer(raw_optimizer,
                                   max_grad_norm=self.cfg.grad_norm_clipping)
        scheduler = optim.lr_scheduler.StepLR(self.optimizer.optimizer,
                                              step_size=3,
                                              gamma=0.8,
                                              last_epoch=self.start_epoch - 1)
        self.optimizer.set_scheduler(scheduler)

        # The text encoder gets its own optimizer and warmup schedule.
        num_train_steps = int(
            len(train_db) / self.cfg.accumulation_steps * self.cfg.n_epochs)
        num_warmup_steps = int(num_train_steps * self.cfg.warmup)
        self.bert_optimizer = AdamW([{
            'params': net.text_encoder.parameters(),
            'initial_lr': self.cfg.finetune_lr
        }],
                                    lr=self.cfg.finetune_lr)
        self.bert_scheduler = get_linear_schedule_with_warmup(
            self.bert_optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_train_steps,
            last_epoch=self.start_epoch - 1)

        bucket_boundaries = [4, 8, 12, 16, 22]
        print('preparing training bucket sampler')
        self.train_bucket_sampler = BucketSampler(
            train_db, bucket_boundaries, batch_size=self.cfg.batch_size)
        print('preparing validation bucket sampler')
        self.val_bucket_sampler = BucketSampler(val_db,
                                                bucket_boundaries,
                                                batch_size=4)

        ##################################################################
        ## LOG
        ##################################################################
        logz.configure_output_dir(self.cfg.model_dir)
        logz.save_config(self.cfg)

        ##################################################################
        ## Main loop
        ##################################################################
        start = time()

        for epoch in range(self.start_epoch, self.cfg.n_epochs):
            ##################################################################
            ## Training
            ##################################################################
            print('Training...')
            torch.cuda.empty_cache()
            train_pred_loss, train_attn_loss, train_eos_loss, train_accu, train_mse = \
                self.train_epoch(train_db, self.optimizer, epoch)

            ##################################################################
            ## Validation
            ##################################################################
            print('Validation...')
            val_loss, val_accu, val_mse, val_infos = self.validate_epoch(
                val_db)

            ##################################################################
            ## Sample
            ##################################################################
            if self.cfg.if_sample:
                print('Sample...')
                torch.cuda.empty_cache()
                self.sample(epoch, test_db, self.cfg.n_samples)
                torch.cuda.empty_cache()

            ##################################################################
            ## Logging
            ##################################################################
            # Hand the mean validation loss to the optimizer wrapper, which
            # owns the StepLR scheduler.
            print('Loging...')
            self.optimizer.update(np.mean(val_loss), epoch)

            logz.log_tabular("Time", time() - start)
            logz.log_tabular("Iteration", epoch)

            logz.log_tabular("TrainAverageError", np.mean(train_pred_loss))
            logz.log_tabular("TrainAverageAccu", np.mean(train_accu))
            logz.log_tabular("TrainAverageMse", np.mean(train_mse))
            logz.log_tabular("ValAverageError", np.mean(val_loss))
            logz.log_tabular("ValAverageAccu", np.mean(val_accu))
            logz.log_tabular("ValAverageObjAccu", np.mean(val_accu[:, 0]))
            logz.log_tabular("ValAverageCoordAccu", np.mean(val_accu[:, 1]))
            logz.log_tabular("ValAverageScaleAccu", np.mean(val_accu[:, 2]))
            logz.log_tabular("ValAverageRatioAccu", np.mean(val_accu[:, 3]))
            logz.log_tabular("ValAverageMse", np.mean(val_mse))
            logz.log_tabular("ValAverageXMse", np.mean(val_mse[:, 0]))
            logz.log_tabular("ValAverageYMse", np.mean(val_mse[:, 1]))
            logz.log_tabular("ValAverageWMse", np.mean(val_mse[:, 2]))
            logz.log_tabular("ValAverageHMse", np.mean(val_mse[:, 3]))
            logz.log_tabular("ValUnigramF3", np.mean(val_infos.unigram_F3()))
            logz.log_tabular("ValBigramF3", np.mean(val_infos.bigram_F3()))
            logz.log_tabular("ValUnigramP", np.mean(val_infos.unigram_P()))
            logz.log_tabular("ValUnigramR", np.mean(val_infos.unigram_R()))
            logz.log_tabular("ValBigramP", val_infos.mean_bigram_P())
            logz.log_tabular("ValBigramR", val_infos.mean_bigram_R())
            logz.log_tabular("ValUnigramScale", np.mean(val_infos.scale()))
            logz.log_tabular("ValUnigramRatio", np.mean(val_infos.ratio()))
            logz.log_tabular("ValUnigramSim",
                             np.mean(val_infos.unigram_coord()))
            logz.log_tabular("ValBigramSim", val_infos.mean_bigram_coord())

            logz.dump_tabular()

            ##################################################################
            ## Checkpoint
            ##################################################################
            print('Saving checkpoint...')
            log_info = [np.mean(val_loss), np.mean(val_accu)]
            self.save_checkpoint(epoch, log_info)
            torch.cuda.empty_cache()

    def train_epoch(self, train_db, optimizer, epoch):
        """Train for one epoch using gradient accumulation.

        Args:
            train_db: Training dataset; batches are drawn through
                ``self.train_bucket_sampler``.
            optimizer: Kept for interface compatibility. Parameter updates go
                through ``self.optimizer`` and ``self.bert_optimizer``.
            epoch: Epoch index, only used in the progress printout.

        Returns:
            Tuple of five lists ``(pred_loss, attn_loss, eos_loss, accu,
            mse)``. One entry is appended per optimizer update, i.e. every
            ``cfg.accumulation_steps`` batches, taken from the batch that
            triggered the update.
        """
        train_db.cfg.sent_group = -1

        train_loader = DataLoader(train_db,
                                  batch_size=1,
                                  num_workers=self.cfg.num_workers,
                                  batch_sampler=self.train_bucket_sampler,
                                  drop_last=False,
                                  pin_memory=False)

        train_pred_loss, train_attn_loss, train_eos_loss, train_accu, train_mse = [], [], [], [], []

        # Unwrap according to what self.net really is: __init__ only uses
        # DataParallel when more than one GPU is visible, so the config
        # flags alone are not a reliable indicator.
        if isinstance(self.net, nn.DataParallel):
            net = self.net.module
        else:
            net = self.net

        for cnt, batched in tqdm(enumerate(train_loader)):
            ##################################################################
            ## Batched data
            ##################################################################
            input_sentences, input_lens, bg_imgs, fg_onehots, \
            gt_inds, gt_msks, gt_scene_inds, boxes, box_msks = \
                self.batch_data(batched)

            ##################################################################
            ## Train one step
            ##################################################################
            self.net.train()

            if self.cfg.teacher_forcing:
                # Go through the (possibly DataParallel-wrapped) module.
                inputs = (input_sentences, input_lens, bg_imgs, fg_onehots)
                inf_outs, _ = self.net(inputs)
            else:
                inf_outs, _ = net.inference(input_sentences, input_lens, -1,
                                            -0.1, 0, gt_inds)

            pred_loss, attn_loss, eos_loss, pred_accu, pred_mse = self.evaluate(
                inf_outs, gt_inds, gt_msks, boxes, box_msks)

            # Scale so that the accumulated gradient matches the mean over
            # the accumulation window.
            loss = pred_loss + attn_loss + eos_loss
            loss = loss / self.cfg.accumulation_steps

            loss.backward()

            if ((cnt + 1) % self.cfg.accumulation_steps) == 0:
                self.optimizer.step()
                self.bert_optimizer.step()
                self.bert_scheduler.step()
                self.net.zero_grad()
                self.global_step += 1

                ##################################################################
                ## Collect info
                ##################################################################
                train_pred_loss.append(pred_loss.cpu().data.item())
                # attn_loss / eos_loss are plain zeros when disabled.
                if attn_loss == 0:
                    attn_loss_np = 0
                else:
                    attn_loss_np = attn_loss.cpu().data.item()
                train_attn_loss.append(attn_loss_np)
                if eos_loss == 0:
                    eos_loss_np = 0
                else:
                    eos_loss_np = eos_loss.cpu().data.item()
                train_eos_loss.append(eos_loss_np)
                train_accu.append(pred_accu.cpu().data.numpy())
                train_mse.append(pred_mse.cpu().data.numpy())

                ##################################################################
                ## Print info
                ##################################################################
                if self.global_step % self.cfg.log_per_steps == 0:
                    print('Epoch %03d, iter %07d:' % (epoch, cnt))
                    print('loss: ', np.mean(train_pred_loss),
                          np.mean(train_attn_loss), np.mean(train_eos_loss))
                    print('accu: ', np.mean(np.array(train_accu), 0))
                    print('xy & wh mse: ', np.mean(np.array(train_mse), 0))
                    # Synchronising is only meaningful (and only possible)
                    # when training on a GPU.
                    if self.cfg.cuda:
                        torch.cuda.synchronize()
                    print('-------------------------')

        return train_pred_loss, train_attn_loss, train_eos_loss, train_accu, train_mse

    def validate_epoch(self, val_db):
        """Run one pass over `val_db` and gather validation metrics.

        Every batch is processed twice with gradients disabled: the unwrapped
        network's `inference` produces an environment that scores the
        prediction against the ground-truth scenes, and a direct call of
        `self.net` produces the outputs that `self.evaluate` turns into loss,
        accuracy and MSE values.

        Args:
            val_db: validation dataset. Its `cfg.sent_group` is overwritten
                (set to 0) and its `scenedb` is indexed with the scene indices
                returned by `self.batch_data`.

        Returns:
            Tuple `(val_loss, val_accu, val_mse, infos)` where the first three
            are numpy arrays with one leading entry per batch and `infos` is
            the result of `eval_info` on the stacked per-scene scores.
        """
        losses, accuracies, mses, scene_scores = [], [], [], []
        # `inference` is called on the underlying module when the network is
        # wrapped for multi-GPU execution.
        wrapped = self.cfg.cuda and self.cfg.parallel
        net = self.net.module if wrapped else self.net

        for group in range(1):
            val_db.cfg.sent_group = group

            loader = DataLoader(val_db,
                                batch_size=1,
                                num_workers=self.cfg.num_workers,
                                batch_sampler=self.val_bucket_sampler,
                                drop_last=False,
                                pin_memory=False)

            for _, batch in tqdm(enumerate(loader)):
                # Unpack the batched tensors.
                (input_inds, input_lens, bg_imgs, fg_onehots, gt_inds,
                 gt_msks, gt_scene_inds, boxes,
                 box_msks) = self.batch_data(batch)
                # Copies, so that scoring cannot modify the stored scenes.
                gt_scenes = [
                    deepcopy(val_db.scenedb[idx]) for idx in gt_scene_inds
                ]

                # One validation step.
                self.net.eval()
                with torch.no_grad():
                    _, env = net.inference(input_inds, input_lens, -1, 2.0, 0,
                                           None)
                    batch_scores = env.batch_evaluation(gt_scenes)
                    inf_outs, _ = self.net(
                        (input_inds, input_lens, bg_imgs, fg_onehots))
                    metrics = self.evaluate(inf_outs, gt_inds, gt_msks, boxes,
                                            box_msks)
                # Only the prediction loss is tracked during validation.
                pred_loss, _, _, pred_accu, pred_mse = metrics

                scene_scores.extend(batch_scores)
                losses.append(pred_loss.cpu().data.item())
                accuracies.append(pred_accu.cpu().data.numpy())
                mses.append(pred_mse.cpu().data.numpy())

        scene_scores = np.stack(scene_scores, 0)
        losses = np.array(losses)
        accuracies = np.stack(accuracies, 0)
        mses = np.stack(mses, 0)
        infos = eval_info(self.cfg, scene_scores)

        return losses, accuracies, mses, infos

    def sample(self, epoch, test_db, N, random_or_not=False):
        """Run inference on up to `N` entries of `test_db` and save the results.

        For every selected entry the following files are written below
        `<cfg.model_dir>/<epoch as %03d>/`:

        * `vis/<name>.png`   - figure with one subplot per generated frame
          (titled with the decoded attention words when attention is enabled)
          and the ground-truth image in the last cell of a 4x4 grid,
        * `pred/<name>.png`  - the last generated frame,
        * `color/<name>.png` - the squared and resized ground-truth image,
        * `gt/<name>.png`    - the ground-truth scene rendered by
          `self.db.render_scene_as_output`.

        Args:
            epoch (int): epoch number; used for the output directory name and
                the progress message.
            test_db: dataset indexable by integer. Entries are read through
                the keys 'color_path', 'bert_inds', 'bert_lens' and
                'sentence'; `test_db.scenedb[i]` supplies the ground truth.
            N (int): maximum number of entries to process.
            random_or_not (bool): visit the entries in a random order instead
                of dataset order.
        """
        ##############################################################
        # Output prefix
        ##############################################################
        epoch_dir = osp.join(self.cfg.model_dir, '%03d' % epoch)
        output_dir = osp.join(epoch_dir, 'vis')
        pred_dir = osp.join(epoch_dir, 'pred')
        gt_dir = osp.join(epoch_dir, 'gt')
        img_dir = osp.join(epoch_dir, 'color')

        maybe_create(output_dir)
        maybe_create(pred_dir)
        maybe_create(gt_dir)
        maybe_create(img_dir)

        ##############################################################
        # Main loop
        ##############################################################
        plt.switch_backend('agg')
        if random_or_not:
            indices = np.random.permutation(range(len(test_db)))
        else:
            indices = range(len(test_db))
        indices = indices[:N]

        # `inference` lives on the underlying module when the network is
        # wrapped for multi-GPU execution.
        if self.cfg.cuda and self.cfg.parallel:
            net = self.net.module
        else:
            net = self.net

        for i in indices:
            entry = test_db[i]
            gt_scene = test_db.scenedb[i]

            gt_img = cv2.imread(entry['color_path'], cv2.IMREAD_COLOR)
            gt_img, _, _ = create_squared_image(gt_img)
            gt_img = cv2.resize(gt_img,
                                (self.cfg.draw_size[0], self.cfg.draw_size[1]))

            ##############################################################
            # Inputs
            ##############################################################
            # Both arrays must exist: the tensor version feeds the network and
            # the numpy version is needed to decode the attention weights.
            # Previously neither name was defined, so the method raised
            # NameError on its first iteration.
            input_inds_np = np.array(entry['bert_inds'])
            input_lens_np = np.array(entry['bert_lens'])

            input_inds = torch.from_numpy(input_inds_np).long().unsqueeze(0)
            input_lens = torch.from_numpy(input_lens_np).long().unsqueeze(0)
            if self.cfg.cuda:
                input_inds = input_inds.cuda()
                input_lens = input_lens.cuda()

            ##############################################################
            # Inference
            ##############################################################
            self.net.eval()
            with torch.no_grad():
                inf_outs, env = net.inference(input_inds, input_lens, -1, 2.0,
                                              0, None)
            frames = env.batch_redraw(return_sequence=True)[0]
            # The attention weights are the last two inference outputs.
            what_wei, where_wei = inf_outs[-2:]

            # NOTE(review): decode_attention looks the indices up in
            # self.db.lang_vocab - confirm that this vocabulary matches the
            # 'bert_inds' used as input here.
            if self.cfg.what_attn:
                what_attn_words = self.decode_attention(
                    input_inds_np, input_lens_np, what_wei.squeeze(0))
            if self.cfg.where_attn > 0:
                where_attn_words = self.decode_attention(
                    input_inds_np, input_lens_np, where_wei.squeeze(0))

            ##############################################################
            # Draw
            ##############################################################
            fig = plt.figure(figsize=(60, 30))
            plt.suptitle(entry['sentence'], fontsize=50)
            for j in range(frames.shape[0]):
                subtitle = ''
                if self.cfg.what_attn:
                    subtitle = subtitle + ' '.join(what_attn_words[j])
                if self.cfg.where_attn > 0:
                    subtitle = subtitle + '\n' + ' '.join(where_attn_words[j])

                plt.subplot(4, 4, j + 1)
                plt.title(subtitle, fontsize=30)
                # Reverse the channel axis for matplotlib (images come from
                # OpenCV).
                plt.imshow(frames[j, :, :, ::-1])
                plt.axis('off')
            # The last grid cell is reserved for the ground-truth image.
            plt.subplot(4, 4, 16)
            plt.imshow(gt_img[:, :, ::-1])
            plt.axis('off')

            name = osp.splitext(osp.basename(entry['color_path']))[0]
            out_path = osp.join(output_dir, name + '.png')
            fig.savefig(out_path, bbox_inches='tight')
            plt.close(fig)

            cv2.imwrite(osp.join(pred_dir, name + '.png'), frames[-1])
            cv2.imwrite(osp.join(img_dir, name + '.png'), gt_img)
            gt_layout = self.db.render_scene_as_output(gt_scene, False, gt_img)
            cv2.imwrite(osp.join(gt_dir, name + '.png'), gt_layout)
            print('sampling: %d, %d' % (epoch, i))

    def show_metric(self, epoch, test_db, N, random_or_not=False):
        """Visualise scene-graph metrics for up to `N` entries of `test_db`.

        For every selected entry a scene graph is built for the predicted
        scene and for the ground-truth scene, the two are compared with
        `evaluator.evaluate_graph`, and a two-panel figure (prediction left,
        ground truth right) is saved to
        `<cfg.model_dir>/<epoch as %03d>/metric/<name>.png`. All unigrams and
        bigrams are drawn first; those the evaluator reports as common to both
        graphs are then drawn over them in a second colour. The figure title
        holds the sentence followed by the metric values from `eval_info`.

        Side effects: sets `test_db.cfg.sent_group` to 1 and writes the raw
        prediction / ground-truth images as `%09d_b.png` / `%09d_i.png` into
        the current working directory.

        Args:
            epoch (int): epoch number used for the output directory name.
            test_db: dataset indexable by integer. Entries are read through
                the keys 'color_path', 'word_inds', 'word_lens' and
                'sentence'; `test_db.scenedb[i]` supplies the ground truth.
            N (int): maximum number of entries to process.
            random_or_not (bool): visit the entries in a random order instead
                of dataset order.
        """
        ##############################################################
        # Output prefix
        ##############################################################
        output_dir = osp.join(self.cfg.model_dir, '%03d' % epoch, 'metric')
        maybe_create(output_dir)
        ##############################################################
        # Main loop
        ##############################################################
        plt.switch_backend('agg')
        if random_or_not:
            indices = np.random.permutation(range(len(test_db)))
        else:
            indices = range(len(test_db))
        indices = indices[:N]
        test_db.cfg.sent_group = 1

        # `inference` is a method of the underlying module, so unwrap the
        # network when it runs under a parallel wrapper.
        if self.cfg.cuda and self.cfg.parallel:
            net = self.net.module
        else:
            net = self.net

        ev = evaluator(self.db)

        for i in indices:
            entry = test_db[i]
            gt_scene = test_db.scenedb[i]
            name = osp.splitext(osp.basename(entry['color_path']))[0]

            ##############################################################
            # Inputs
            ##############################################################
            input_inds_np = np.array(entry['word_inds'])
            input_lens_np = np.array(entry['word_lens'])

            input_inds = torch.from_numpy(input_inds_np).long().unsqueeze(0)
            input_lens = torch.from_numpy(input_lens_np).long().unsqueeze(0)
            if self.cfg.cuda:
                input_inds = input_inds.cuda()
                input_lens = input_lens.cuda()

            ##############################################################
            # Inference
            ##############################################################
            self.net.eval()
            with torch.no_grad():
                # Use the unwrapped network: calling `inference` on the
                # wrapper itself fails in the parallel configuration.
                _, env = net.inference(input_inds, input_lens, -1, 2.0, 0,
                                       None)
            frame = env.batch_redraw(return_sequence=False)[0][0]
            raw_pred_scene = env.scenes[0]

            # Copy before stacking so the environment's scene is not touched.
            pred_inds = deepcopy(raw_pred_scene['out_inds'])
            pred_inds = np.stack(pred_inds, 0)
            pred_scene = self.db.output_inds_to_scene(pred_inds)

            graph_1 = scene_graph(self.db, pred_scene, None, False)
            graph_2 = scene_graph(self.db, gt_scene, None, False)

            color_1 = frame.copy()
            gt_img = cv2.imread(entry['color_path'], cv2.IMREAD_COLOR)
            gt_img, _, _ = create_squared_image(gt_img)
            gt_img = cv2.resize(gt_img,
                                (self.cfg.draw_size[0], self.cfg.draw_size[1]))
            color_2 = gt_img

            # NOTE(review): these two images go to the current working
            # directory rather than `output_dir` - confirm this is intended.
            cv2.imwrite('%09d_b.png' % i, color_1)
            cv2.imwrite('%09d_i.png' % i, color_2)

            # First pass: every unigram / bigram of each graph.
            color_1 = visualize_unigram(self.cfg, color_1, graph_1.unigrams,
                                        (225, 0, 0))
            color_2 = visualize_unigram(self.cfg, color_2, graph_2.unigrams,
                                        (225, 0, 0))
            color_1 = visualize_bigram(self.cfg, color_1, graph_1.bigrams,
                                       (0, 0, 255))
            color_2 = visualize_bigram(self.cfg, color_2, graph_2.bigrams,
                                       (0, 0, 255))

            # The evaluator exposes the common elements after this call.
            scores = ev.evaluate_graph(graph_1, graph_2)

            # Second pass: redraw the common elements in another colour.
            color_1 = visualize_unigram(self.cfg, color_1,
                                        ev.common_pred_unigrams, (0, 225, 0))
            color_2 = visualize_unigram(self.cfg, color_2,
                                        ev.common_gt_unigrams, (0, 225, 0))
            color_1 = visualize_bigram(self.cfg, color_1,
                                       ev.common_pred_bigrams, (0, 255, 255))
            color_2 = visualize_bigram(self.cfg, color_2, ev.common_gt_bigrams,
                                       (0, 255, 255))

            # `eval_info` receives the scores with an added leading axis.
            info = eval_info(self.cfg, scores[None, ...])

            fig = plt.figure(figsize=(16, 10))
            title = entry['sentence']
            title += 'UR:%f,UP:%f,BR:%f,BP:%f\n' % (
                info.unigram_R()[0], info.unigram_P()[0], info.bigram_R()[0],
                info.bigram_P()[0])
            title += 'scale: %f, ratio: %f, coord: %f, b:%f \n' % (
                info.scale()[0], info.ratio()[0], info.unigram_coord()[0],
                info.bigram_coord()[0])

            plt.suptitle(title)
            plt.subplot(1, 2, 1)
            # Reverse the channel axis for matplotlib (images come from
            # OpenCV).
            plt.imshow(color_1[:, :, ::-1])
            plt.axis('off')
            plt.subplot(1, 2, 2)
            plt.imshow(color_2[:, :, ::-1])
            plt.axis('off')

            out_path = osp.join(output_dir, name + '.png')
            fig.savefig(out_path, bbox_inches='tight')
            plt.close(fig)

    def sample_all_top1(self, test_db):
        """Generate, score and dump the first predicted scene for all entries.

        The whole of `test_db` is traversed five times, once per sentence
        group (`test_db.cfg.sent_group` is set to 0..4). For each entry the
        rendered frame is written to
        `<cfg.model_dir>/top1_scenes/<group>/images/` and a JSON description
        of the predicted scene to `.../<group>/scenes/`. The scores of all
        predictions are finally summarised with `eval_info` and logged to
        `<cfg.model_dir>/top1_scenes/eval_info_top1.json`.

        Args:
            test_db: dataset indexable by integer. Entries are read through
                the keys 'word_inds', 'word_lens' and 'sentence';
                `test_db.scenedb[i]` supplies the ground-truth scene.
        """
        ##############################################################
        # Output prefix
        ##############################################################
        out_dir = osp.join(self.cfg.model_dir, 'top1_scenes')
        maybe_create(out_dir)
        # `inference` is called on the underlying module when the network is
        # wrapped for multi-GPU execution.
        wrapped = self.cfg.cuda and self.cfg.parallel
        net = self.net.module if wrapped else self.net
        indices = range(len(test_db))
        all_scores = []

        for G in range(5):
            test_db.cfg.sent_group = G
            group_dir = osp.join(out_dir, '%02d' % G)
            maybe_create(group_dir)
            img_dir = osp.join(group_dir, 'images')
            maybe_create(img_dir)
            scene_dir = osp.join(group_dir, 'scenes')
            maybe_create(scene_dir)

            for i in indices:
                entry = test_db[i]
                gt_scene = test_db.scenedb[i]

                ##############################################################
                # Inputs
                ##############################################################
                inds_np = np.array(entry['word_inds'])
                lens_np = np.array(entry['word_lens'])
                input_inds = torch.from_numpy(inds_np).long().unsqueeze(0)
                input_lens = torch.from_numpy(lens_np).long().unsqueeze(0)
                if self.cfg.cuda:
                    input_inds = input_inds.cuda()
                    input_lens = input_lens.cuda()

                ##############################################################
                # Inference
                ##############################################################
                self.net.eval()
                with torch.no_grad():
                    _, env = net.inference(input_inds, input_lens, -1, 2.0, 0,
                                           None)
                # Only the first scene of the environment is kept.
                pred_scene = env.scenes[0]

                ##############################################################
                # Evaluate
                ##############################################################
                all_scores.append(env.evaluate_scene(pred_scene, gt_scene))

                ##############################################################
                # Draw
                ##############################################################
                frame = env.batch_redraw(return_sequence=False)[0][0]

                ##############################################################
                # Save scene
                ##############################################################
                # Work on a copy so the environment's scene is left untouched.
                pred_inds = np.stack(deepcopy(pred_scene['out_inds']), 0)
                layout = test_db.output_inds_to_scene(pred_inds)
                # Plain Python containers only, so the scene is serialisable.
                out_scene = {
                    'boxes': [
                        box.tolist()
                        for box in layout['boxes'].astype(np.float64)
                    ],
                    'clses': layout['clses'].astype(np.int64).tolist(),
                    'caption': entry['sentence'],
                    'width': int(gt_scene['width']),
                    'height': int(gt_scene['height']),
                    'img_idx': int(gt_scene['img_idx']),
                }
                img_idx = out_scene['img_idx']

                # Files are named "<group>_<image index padded to 12 digits>".
                stem = '%02d_%s' % (G, str(img_idx).zfill(12))
                scene_path = osp.join(scene_dir, stem + '.json')
                img_path = osp.join(img_dir, stem + '.jpg')
                cv2.imwrite(img_path, frame)
                with open(scene_path, 'w') as fp:
                    json.dump(out_scene, fp, indent=4, sort_keys=True)
                print(G, i, img_idx)

        score_matrix = np.stack(all_scores, 0).astype(np.float64)
        infos = eval_info(self.cfg, score_matrix)
        log_coco_scores(infos, osp.join(out_dir, 'eval_info_top1.json'))

    def sample_demo(self, input_sentences):
        """Generate layouts for free-form sentences and save them as figures.

        Each sentence is encoded with `self.db.encode_sentence`, run through
        the network's `inference`, and the returned frame sequence is plotted
        on a 4x3 grid saved as
        `<cfg.model_dir>/bert_layout_samples/<index as %09d>.png`. The
        predicted class names and the corresponding entries of the second
        inference output are printed until the first prediction that is not
        greater than `cfg.EOS_idx`.

        Args:
            input_sentences: sequence of sentences accepted by
                `self.db.encode_sentence`.
        """
        output_dir = osp.join(self.cfg.model_dir, 'bert_layout_samples')
        print(output_dir)
        maybe_create(output_dir)

        ##############################################################
        # Main loop
        ##############################################################
        plt.switch_backend('agg')
        # `inference` is called on the underlying module when the network is
        # wrapped for multi-GPU execution.
        wrapped = self.cfg.cuda and self.cfg.parallel
        net = self.net.module if wrapped else self.net

        for i, sentence in enumerate(input_sentences):
            ##############################################################
            # Inputs
            ##############################################################
            word_inds, word_lens = self.db.encode_sentence(sentence)
            inds_np = np.array(word_inds)
            lens_np = np.array(word_lens)
            input_inds = torch.from_numpy(inds_np).long().unsqueeze(0)
            input_lens = torch.from_numpy(lens_np).long().unsqueeze(0)
            if self.cfg.cuda:
                input_inds = input_inds.cuda()
                input_lens = input_lens.cuda()

            ##############################################################
            # Inference
            ##############################################################
            self.net.eval()
            with torch.no_grad():
                inf_outs, env = net.inference(input_inds, input_lens, -1, 2.0,
                                              0, None)
            frames = env.batch_redraw(return_sequence=True)[0]

            # Index of the highest-scoring class at every step of the first
            # (and only) sample.
            objs = torch.max(inf_outs[0], -1)[1][0].cpu().data
            print('------------{}------------'.format(i))
            for k, obj in enumerate(objs):
                # Stop at the first index that is not above EOS_idx.
                if obj <= self.cfg.EOS_idx:
                    break

                print(self.db.classes[obj])
                print(inf_outs[1][0][k])

            ##############################################################
            # Draw
            ##############################################################
            fig = plt.figure(figsize=(60, 40))
            plt.suptitle(sentence, fontsize=40)
            for j, frame in enumerate(frames):
                plt.subplot(4, 3, j + 1)
                # Reverse the channel axis before handing the frame to
                # matplotlib.
                plt.imshow(frame[:, :, ::-1])
                plt.axis('off')
            fig.savefig(osp.join(output_dir, '%09d.png' % i),
                        bbox_inches='tight')
            plt.close(fig)

    def decode_attention(self, word_inds, word_lens, att_logits):
        """Translate the three strongest attention positions per step to words.

        Args:
            word_inds: numpy array of word indices. A 1-D array is used as is.
                An array with more dimensions is flattened row by row, keeping
                the first `word_lens[i]` entries of row `i`, and then padded
                with zeros up to `3 * self.cfg.max_input_length` entries.
            word_lens: valid length of every row of `word_inds`; only read
                when `word_inds` has more than one dimension.
            att_logits: 2-D torch tensor of attention scores with one row per
                step; its last axis indexes positions of the (flattened)
                `word_inds`.

        Returns:
            list: one entry per row of `att_logits`, each a list with the
            three words (looked up in `self.db.lang_vocab.index2word`) at the
            positions with the largest scores, strongest first.
        """
        top_positions = torch.topk(att_logits, 3, -1)[1].cpu().data.numpy()

        if len(word_inds.shape) <= 1:
            flat_inds = word_inds.copy()
        else:
            # Concatenate the valid prefix of every row, then pad with index 0
            # to the fixed length the attention positions refer to.
            flat_inds = []
            for row_idx, row in enumerate(word_inds):
                flat_inds.extend(row[:word_lens[row_idx]].tolist())
            pad = self.cfg.max_input_length * 3 - len(flat_inds)
            flat_inds = np.array(flat_inds + [0] * pad).astype(np.int32)

        num_steps, _ = top_positions.shape
        decoded = []
        for step in range(num_steps):
            picked = [flat_inds[pos] for pos in top_positions[step]]
            decoded.append(
                [self.db.lang_vocab.index2word[idx] for idx in picked])

        return decoded

    def save_checkpoint(self, epoch, log):
        """Write the model and text-encoder checkpoints for `epoch`.

        Two files are stored in `<cfg.model_dir>/bert_layout_ckpts/` (created
        when missing):

        * `ckpt-<epoch>-<log[0]>-<log[1]>.pkl` with the network weights and
          the state of `self.optimizer.optimizer`,
        * `bert-ckpt-<epoch>-<log[0]>-<log[1]>.pkl` with the weights of
          `net.text_encoder.embedding` and the state of `self.bert_optimizer`.

        Each file is a dict with the keys 'net', 'optimizer' and 'epoch'.

        Args:
            epoch (int): epoch number, stored in the files and used (as %03d)
                in the file names.
            log: indexable holding two numbers that are formatted with four
                decimals into the file names.
        """
        print(" [*] Saving checkpoints...")
        # Save the weights of the underlying module when the network is
        # wrapped for multi-GPU execution.
        wrapped = self.cfg.cuda and self.cfg.parallel
        net = self.net.module if wrapped else self.net

        checkpoint_dir = osp.join(self.cfg.model_dir, 'bert_layout_ckpts')
        if not osp.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        # Both files share the same "<epoch>-<log[0]>-<log[1]>" suffix.
        model_name = "ckpt-%03d-%.4f-%.4f.pkl" % (epoch, log[0], log[1])
        bert_name = "bert-" + model_name

        print('saving ckpt to ', checkpoint_dir)
        torch.save(
            dict(net=net.state_dict(),
                 optimizer=self.optimizer.optimizer.state_dict(),
                 epoch=epoch), osp.join(checkpoint_dir, model_name))

        print('saving bert to ', checkpoint_dir)
        torch.save(
            dict(net=net.text_encoder.embedding.state_dict(),
                 optimizer=self.bert_optimizer.state_dict(),
                 epoch=epoch), osp.join(checkpoint_dir, bert_name))
Пример #25
0
    model = VideoModel(config.model)

    if config.training.num_gpu > 0:
        model = model.cuda()
        if config.training.num_gpu > 1:
            device_ids = list(range(config.training.num_gpu))
            model = torch.nn.DataParallel(model, device_ids=device_ids)
        logger.info('Loaded the model to %d GPUs' % config.training.num_gpu)

    n_params, enc, dec, fir_enc = count_parameters(model)
    logger.info('# the number of parameters in the whole model: %d' % n_params)
    logger.info('# the number of parameters in the Encoder: %d' % enc)
    logger.info('# the number of parameters in the Decoder: %d' %
                (n_params - enc))
    optimizer = Optimizer(model.parameters(), config.optim)
    logger.info('Created a %s optimizer.' % config.optim.type)

    start_epoch = 0

    # create a visualizer
    if config.training.visualization:
        visualizer = SummaryWriter(os.path.join(exp_name, 'log'))
        logger.info('Created a visualizer.')
    else:
        visualizer = None

    for epoch in range(start_epoch, config.training.epochs):
        # eval_ctc_model(epoch, config, model, val_data, logger, visualizer)
        train_ctc_model(epoch, config, model, train_data, optimizer, logger,
                        visualizer)
Пример #26
0
                                 n_layers=opt.rnn_layers,
                                 rnn_cell='lstm',
                                 dropout_p=0.2,
                                 use_attention=True,
                                 bidirectional=bidirectional,
                                 eos_id=tgt.eos_id,
                                 sos_id=tgt.sos_id)
            seq2seq = Seq2seq(encoder, decoder)
            if torch.cuda.is_available():
                seq2seq.cuda()

            for param in seq2seq.parameters():
                param.data.uniform_(-0.08, 0.08)

    optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters(),
                                           lr=opt.lr,
                                           betas=(0.9, 0.995)),
                          max_grad_norm=opt.grad_norm)
    # scheduler = StepLR(optimizer.optimizer, 1)
    # optimizer.set_scheduler(scheduler)

    weight = torch.ones(len(tgt.vocab))
    pad = tgt.vocab.stoi[tgt.pad_token]
    loss = CrossEntropyLoss(weight, pad)
    if torch.cuda.is_available():
        loss.cuda()

    t = SupervisedTrainer(loss=loss,
                          batch_size=batch_size,
                          checkpoint_every=100,
                          print_every=10,
                          expt_dir=opt.expt_dir)