Example #1
    def evaluate(self, iters, best):
        """ Free Run """
        eval_metric, bleu_en, bleu_de, en_corpus, en_hyp, de_hyp = self.evaluate_communication()
        write_tb(self.writer, {
            'bleu': bleu_en[0],
            'len_hyp': bleu_en[-1]
        },
                 iters,
                 prefix='bleu_en/')
        write_tb(self.writer, {
            'bleu': bleu_de[0],
            'len_hyp': bleu_de[-1]
        },
                 iters,
                 prefix='bleu_de/')
        write_tb(self.writer, {
            "bleu_en": bleu_en[0],
            "bleu_de": bleu_de[0]
        },
                 iters,
                 prefix="eval/")
        write_tb(self.writer, {
            name: eval_metric.__getattr__(name)
            for name in self.monitor_names
        },
                 iters,
                 prefix="eval/")
        self.args.logger.info('model:' + self.args.prefix + self.args.hp_str)
        best.accumulate(bleu_de[0], bleu_en[0], iters)
        self.args.logger.info(best)

        # Save decoding results
        dest_paths = [
            Path(self.args.decoding_path) / self.args.id_str / name for name in
            ["en_ref", "de_hyp_{}".format(iters), "en_hyp_{}".format(iters)]
        ]
        for dest, lines in zip(dest_paths, [en_corpus, de_hyp, en_hyp]):
            dest.write_text("\n".join(lines), encoding="utf-8")
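Note: write_tb is a project-local helper, not part of tensorboardX. Examples #1 and #6 pass it a dict of scalars, while Examples #2-#5 pass parallel name/value lists. A minimal sketch that would satisfy the dict-style calls above (the list-style variant and the project's real helper may differ):

# Hypothetical sketch only: a dict-based write_tb compatible with the calls
# in Example #1; it just forwards each scalar to the SummaryWriter.
def write_tb(writer, scalars, iters, prefix=""):
    for name, value in scalars.items():
        writer.add_scalar(prefix + name, value, iters)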
Example #2
def train_model(args, model, iterators):
    (train_it, dev_it) = iterators

    if not args.debug:
        decoding_path = Path(join(args.decoding_path, args.id_str))
        decoding_path.mkdir(parents=True, exist_ok=True)
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(join(args.event_path, args.id_str))

    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9, lr=args.lr)
    else:
        raise NotImplementedError

    extra_loss_names = []
    train_metrics = Metrics('train_loss',
                            'nll',
                            *extra_loss_names,
                            data_type="avg")
    dev_metrics = Metrics('dev_loss',
                          'nll',
                          *extra_loss_names,
                          data_type="avg")
    best = Best(max,
                'dev_bleu',
                'iters',
                model=model,
                opt=opt,
                path=join(args.model_path, args.id_str),
                gpu=args.gpu,
                debug=args.debug)

    for iters, train_batch in enumerate(train_it):
        if iters >= args.max_training_steps:
            args.logger.info(
                'stopping training after {} training steps'.format(
                    args.max_training_steps))
            break

        if not args.debug and hasattr(
                args, 'save_every') and iters % args.save_every == 0:
            args.logger.info(
                'save (back-up) checkpoints at iters={}'.format(iters))
            with torch.cuda.device(args.gpu):
                torch.save(
                    model.state_dict(),
                    '{}_iter={}.pt'.format(args.model_path + args.id_str,
                                           iters))
                torch.save([iters, opt.state_dict()],
                           '{}_iter={}.pt.states'.format(
                               args.model_path + args.id_str, iters))

        if iters % args.eval_every == 0:
            dev_metrics.reset()
            dev_bleu = valid_model(args, model, dev_it, dev_metrics, 'argmax')
            if not args.debug:
                write_tb(writer, ['nll'], [dev_metrics.nll],
                         iters,
                         prefix="dev/")
                write_tb(writer, [
                    'bleu', *("p_1 p_2 p_3 p_4".split()), 'bp', 'len_ref',
                    'len_hyp'
                ],
                         dev_bleu,
                         iters,
                         prefix="bleu/")
            best.accumulate(dev_bleu[0], iters)
            args.logger.info(best)
            """
            if args.early_stop and (iters - best.iters) // args.eval_every > args.patience:
                args.logger.info("Early stopping.")
                break
            """

        model.train()

        def get_lr_anneal(iters):
            lr_end = 1e-5
            return max(0, (args.lr - lr_end) *
                       (args.linear_anneal_steps - iters) /
                       args.linear_anneal_steps) + lr_end

        if args.lr_anneal == "linear":
            opt.param_groups[0]['lr'] = get_lr_anneal(iters)

        opt.zero_grad()

        batch_size = len(train_batch)
        if args.dataset == "iwslt" or args.dataset == 'iwslt_small':
            src, src_len = train_batch.src
            trg, trg_len = train_batch.trg
        elif args.dataset == "multi30k":
            src_lang, trg_lang = args.pair.split("_")
            src, src_len = train_batch.__dict__[src_lang]
            trg, trg_len = train_batch.__dict__[trg_lang]
        else:
            raise ValueError

        # NOTE encoder never receives <BOS> token
        # because during communication, Agent A's decoder will never output <BOS>
        logits, _ = model(src[:, 1:], src_len - 1, trg[:, :-1])
        nll = F.cross_entropy(logits,
                              trg[:, 1:].contiguous().view(-1),
                              reduction='mean',
                              ignore_index=0)
        num_trg = (trg_len - 1).sum().item()
        train_metrics.accumulate(num_trg, nll.item())

        nll.backward()
        if args.grad_clip > 0:
            # Clip after backward so gradients exist when clipping.
            nn.utils.clip_grad_norm_(params, args.grad_clip)
        opt.step()

        if iters % args.print_every == 0:
            args.logger.info("update {} : {}".format(iters,
                                                     str(train_metrics)))
            if not args.debug:
                write_tb(writer, ['nll', 'lr'],
                         [train_metrics.nll, opt.param_groups[0]['lr']],
                         iters,
                         prefix="train/")
            train_metrics.reset()
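Metrics is another project-local helper used throughout these examples. A minimal running-average sketch consistent with how it is called here (accumulate(count, *values), attribute access by metric name, reset(), str()); the project's actual implementation is assumed, not shown:

# Hedged sketch: keeps count-weighted running averages of the named values.
class Metrics:
    def __init__(self, name, *metric_names, data_type="avg"):
        self.name = name
        self.metric_names = metric_names
        self.reset()

    def reset(self):
        self.count = 0
        self.sums = {m: 0.0 for m in self.metric_names}

    def accumulate(self, count, *values):
        self.count += count
        for m, v in zip(self.metric_names, values):
            self.sums[m] += v * count

    def __getattr__(self, name):
        sums = self.__dict__.get("sums", {})
        if name in sums:
            return sums[name] / max(self.__dict__.get("count", 1), 1)
        raise AttributeError(name)

    def __str__(self):
        return " ".join(
            "{}={:.4f}".format(m, getattr(self, m)) for m in self.metric_names)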
Example #3
def train_model(args, model):

    resnet = torchvision.models.resnet152(pretrained=True)
    resnet = nn.Sequential(*list(resnet.children())[:-1])
    resnet = nn.DataParallel(resnet).cuda()
    resnet.eval()

    if not args.debug:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(args.event_path + args.id_str)

    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9, lr=args.lr)
    else:
        raise NotImplementedError

    loss_names, loss_cos = ["loss"], {"loss":1.0}
    monitor_names = "cap_r1 cap_r5 cap_r10 img_r1 img_r5 img_r10".split()

    train_metrics = Metrics('train_loss', *loss_names, data_type="avg")
    best = Best(max, 'r1', 'iters', model=model, opt=opt, path=args.model_path + args.id_str, \
                gpu=args.gpu, debug=args.debug)

    # Train dataset
    args.logger.info("Loading train imgs...")
    train_dataset = ImageFolderWithPaths(os.path.join(args.data_dir, 'flickr30k'), preprocess_rc)
    train_imgs = open(os.path.join(args.data_dir, 'flickr30k/train.txt'), 'r').readlines()
    train_imgs = [x.strip() for x in train_imgs if x.strip() != ""]
    train_dataset.samples = [x for x in train_dataset.samples if x[0].split("/")[-1] in train_imgs]
    train_dataset.imgs = [x for x in train_dataset.imgs if x[0].split("/")[-1] in train_imgs]
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8,
                                               pin_memory=False)
    args.logger.info("Train loader built!")

    en_vocab = TextVocab(counter=torch.load(os.path.join(args.data_dir, 'bpe/vocab.en.pth')))
    word2idx = en_vocab.stoi
    train_en = [open(os.path.join(args.data_dir, 'flickr30k/caps', 'train.{}.bpe'.format(idx+1))).readlines() for idx in range(5)]
    train_en = [[["<bos>"] + sentence.strip().split() + ["<eos>"] for sentence in doc if sentence.strip() != "" ] for doc in train_en]
    train_en = [[[word2idx[word] for word in sentence] for sentence in doc] for doc in train_en]

    args.logger.info("Train corpus built!")

    # Valid dataset
    valid_img_feats = torch.tensor(torch.load(os.path.join(args.data_dir, 'flickr30k/val_feat.pth')))
    valid_ens = []
    valid_en_lens = []
    for idx in range(5):
        valid_en = []
        with open(os.path.join(args.data_dir, 'flickr30k/caps', 'val.{}.bpe'.format(idx+1))) as f:
            for line in f:
                line = line.strip()
                if line == "":
                    continue

                words = ["<bos>"] + line.split() + ["<eos>"]
                words = [word2idx[word] for word in words]
                valid_en.append(words)

        # Pad
        valid_en_len = [len(sent) for sent in valid_en]
        valid_en = [np.lib.pad(xx, (0, max(valid_en_len) - len(xx)), 'constant', constant_values=(0, 0)) for xx in valid_en]
        valid_ens.append(torch.tensor(valid_en).long())
        valid_en_lens.append(torch.tensor(valid_en_len).long())
    args.logger.info("Valid corpus built!")

    iters = -1
    should_stop = False
    for epoch in range(999999999):
        if should_stop:
            break

        for idx, (train_img, lab, path) in enumerate(train_loader):
            iters += 1
            if iters > args.max_training_steps:
                should_stop = True
                break

            if iters % args.eval_every == 0:
                res = get_retrieve_result(args, model, valid_caps=valid_ens, valid_lens=valid_en_lens,
                                          valid_img_feats=valid_img_feats)
                val_metrics = valid_model(args, model, valid_img_feats=valid_img_feats,
                                          valid_caps=valid_ens, valid_lens=valid_en_lens)
                args.logger.info("[VALID] update {} : {}".format(iters, str(val_metrics)))
                if not args.debug:
                    write_tb(writer, monitor_names, res, iters, prefix="dev/")
                    write_tb(writer, loss_names, [val_metrics.__getattr__(name) for name in loss_names],
                             iters, prefix="dev/")

                best.accumulate((res[0]+res[3])/2, iters)
                args.logger.info('model:' + args.prefix + args.hp_str)
                args.logger.info('epoch {} iters {}'.format(epoch, iters))
                args.logger.info(best)

                if args.early_stop and (iters - best.iters) // args.eval_every > args.patience:
                    args.logger.info("Early stopping.")
                    return

            model.train()

            def get_lr_anneal(iters):
                lr_end = args.lr_min
                return max(0, (args.lr - lr_end) *
                           (args.linear_anneal_steps - iters) /
                           args.linear_anneal_steps) + lr_end

            if args.lr_anneal == "linear":
                opt.param_groups[0]['lr'] = get_lr_anneal(iters)

            opt.zero_grad()
            batch_size = len(path)
            path = [p.split("/")[-1] for p in path]
            sentence_idx = [train_imgs.index(p) for p in path]
            en = [train_en[random.randint(0, 4)][sentence_i] for sentence_i in sentence_idx]
            en_len = [len(x) for x in en]

            en = [np.lib.pad(xx, (0, max(en_len) - len(xx)), 'constant', constant_values=(0, 0)) for xx in en]
            en = cuda(torch.LongTensor(np.array(en).tolist()))
            en_len = cuda(torch.LongTensor(en_len))

            with torch.no_grad():
                train_img = resnet(train_img).view(batch_size, -1)
            R = model(en[:,1:], en_len-1, train_img)
            if args.img_pred_loss == "vse":
                R['loss'] = R['loss'].sum()
            elif args.img_pred_loss == "mse":
                R['loss'] = R['loss'].mean()
            else:
                raise Exception()

            total_loss = 0
            for loss_name in loss_names:
                total_loss += R[loss_name] * loss_cos[loss_name]

            train_metrics.accumulate(batch_size, *[R[name].item() for name in loss_names])

            total_loss.backward()
            if args.plot_grad:
                plot_grad(writer, model, iters)

            if args.grad_clip > 0:
                nn.utils.clip_grad_norm_(params, args.grad_clip)

            opt.step()

            if iters % args.eval_every == 0:
                args.logger.info("update {} : {}".format(iters, str(train_metrics)))

            if iters % args.eval_every == 0 and not args.debug:
                write_tb(writer, loss_names, [train_metrics.__getattr__(name) for name in loss_names], \
                         iters, prefix="train/")
                write_tb(writer, ['lr'], [opt.param_groups[0]['lr']], iters, prefix="train/")
                train_metrics.reset()
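The caption batching in this example pads variable-length index sequences to a common length with 0 as the assumed <pad> index. The same pattern in isolation, using a hypothetical pad_batch helper:

import numpy as np

# Hypothetical helper illustrating the padding pattern above; 0 is the
# assumed <pad> index.
def pad_batch(seqs, pad_idx=0):
    lens = [len(s) for s in seqs]
    padded = np.array([list(s) + [pad_idx] * (max(lens) - len(s)) for s in seqs])
    return padded, lens

# e.g. pad_batch([[2, 5, 7, 3], [2, 9, 3]]) ->
#   (array([[2, 5, 7, 3], [2, 9, 3, 0]]), [4, 3])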
Example #4
def train_model(args, model, iterators, extra_input):
    (train_its, dev_its) = iterators

    if not args.debug:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(args.event_path + args.id_str)

    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9, lr=args.lr)
    else:
        raise NotImplementedError

    loss_names, loss_cos = ["nll"], {"nll": 1.0}
    monitor_names = ["nll_rnd"]
    """
    if args.rep_pen_co > 0.0:
        loss_names.append("nll_cur")
        loss_cos["nll_cur"] = -1 * args.rep_pen_co
    else:
        monitor_names.append("nll_cur")
    """

    train_metrics = Metrics('train_loss', *loss_names, data_type="avg")
    dev_metrics = Metrics('dev_loss',
                          *loss_names,
                          *monitor_names,
                          data_type="avg")
    best = Best(min, 'loss', 'iters', model=model, opt=opt, path=args.model_path + args.id_str, \
                gpu=args.gpu, debug=args.debug)

    iters = 0
    should_stop = False
    for epoch in range(999999999):
        if should_stop:
            break

        for dataset in args.dataset.split("_"):
            if should_stop:
                break

            train_it = train_its[dataset]
            for _, train_batch in enumerate(train_it):
                if iters >= args.max_training_steps:
                    args.logger.info(
                        'stopping training after {} training steps'.format(
                            args.max_training_steps))
                    should_stop = True
                    break

                if iters % args.eval_every == 0:
                    dev_metrics.reset()
                    valid_model(args, model, dev_its['multi30k'], dev_metrics,
                                iters, loss_names, monitor_names, extra_input)
                    if not args.debug:
                        write_tb(writer, loss_names, [dev_metrics.__getattr__(name) for name in loss_names], \
                                 iters, prefix="dev/")
                        write_tb(writer, monitor_names, [dev_metrics.__getattr__(name) for name in monitor_names], \
                                 iters, prefix="dev/")
                    best.accumulate(dev_metrics.nll, iters)

                    args.logger.info('model:' + args.prefix + args.hp_str)
                    args.logger.info('epoch {} dataset {} iters {}'.format(
                        epoch, dataset, iters))
                    args.logger.info(best)

                    if args.early_stop and (
                            iters -
                            best.iters) // args.eval_every > args.patience:
                        args.logger.info("Early stopping.")
                        return

                model.train()

                def get_lr_anneal(iters):
                    lr_end = args.lr_min
                    return max(0, (args.lr - lr_end) *
                               (args.linear_anneal_steps - iters) /
                               args.linear_anneal_steps) + lr_end

                if args.lr_anneal == "linear":
                    opt.param_groups[0]['lr'] = get_lr_anneal(iters)

                opt.zero_grad()

                batch_size = len(train_batch)
                img_input = None if args.no_img else cuda(
                    extra_input["img"][dataset][0].index_select(
                        dim=0, index=train_batch.idx.cpu()))
                if dataset == "coco":
                    en, en_len = train_batch.__dict__[
                        "_" + str(random.randint(1, 5))]
                elif dataset == "multi30k":
                    en, en_len = train_batch.en

                decoded = model(en, img_input)
                R = {}
                R["nll"] = F.cross_entropy(decoded,
                                           en[:, 1:].contiguous().view(-1),
                                           ignore_index=0)
                #R["nll_cur"] = F.cross_entropy( decoded, en[:,:-1].contiguous().view(-1), ignore_index=0 )

                total_loss = 0
                for loss_name in loss_names:
                    total_loss += R[loss_name] * loss_cos[loss_name]

                train_metrics.accumulate(
                    batch_size, *[R[name].item() for name in loss_names])

                total_loss.backward()
                if args.plot_grad:
                    plot_grad(writer, model, iters)

                if args.grad_clip > 0:
                    nn.utils.clip_grad_norm_(params, args.grad_clip)

                opt.step()
                iters += 1

                if iters % args.eval_every == 0:
                    args.logger.info("update {} : {}".format(
                        iters, str(train_metrics)))

                if iters % args.eval_every == 0 and not args.debug:
                    write_tb(writer, loss_names, [train_metrics.__getattr__(name) for name in loss_names], \
                             iters, prefix="train/")
                    write_tb(writer, ['lr'], [opt.param_groups[0]['lr']],
                             iters,
                             prefix="train/")
                    train_metrics.reset()
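Best is the third project-local helper these examples rely on: it is built with a comparison function and a list of tracked names, and accumulate() is expected to record new values (and checkpoint the model) whenever the first value improves. A heavily simplified sketch with the checkpointing omitted; the real class may differ:

# Hedged sketch of a Best tracker consistent with how it is called here.
class Best:
    def __init__(self, cmp, *names, model=None, opt=None, path="", gpu=0, debug=False):
        self.cmp = cmp                       # max or min
        self.names = names                   # e.g. ('dev_bleu', 'iters')
        self.values = {n: None for n in names}

    def accumulate(self, *values):
        key = self.names[0]
        if self.values[key] is None or self.cmp(values[0], self.values[key]) == values[0]:
            # New best: record all tracked values (a real implementation
            # would also save model/optimizer state here).
            self.values = dict(zip(self.names, values))

    def __getattr__(self, name):
        values = self.__dict__.get("values", {})
        if name in values:
            return values[name]
        raise AttributeError(name)

    def __str__(self):
        return "best " + " ".join("{}={}".format(n, v) for n, v in self.values.items())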
Example #5
def train_model(args, model, iterators, extra_input):
    (train_its, dev_its) = iterators

    if not args.debug:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(args.event_path + args.id_str)

    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9, lr=args.lr)
    else:
        raise NotImplementedError

    loss_names, loss_cos = ["loss"], {"loss": 1.0}
    monitor_names = "cap_r1 cap_r5 cap_r10 img_r1 img_r5 img_r10".split()

    train_metrics = Metrics('train_loss', *loss_names, data_type="avg")
    best = Best(max, 'r1', 'iters', model=model, opt=opt, path=args.model_path + args.id_str, \
                gpu=args.gpu, debug=args.debug)

    iters = 0
    for epoch in range(999999999):
        for dataset in args.dataset.split("_"):
            train_it = train_its[dataset]
            for _, train_batch in enumerate(train_it):
                iters += 1

                if iters % args.eval_every == 0:
                    R = valid_model(args, model, dev_its['multi30k'],
                                    extra_input)
                    if not args.debug:
                        write_tb(writer,
                                 monitor_names,
                                 R,
                                 iters,
                                 prefix="dev/")
                    best.accumulate((R[0] + R[3]) / 2, iters)

                    args.logger.info('model:' + args.prefix + args.hp_str)
                    args.logger.info('epoch {} dataset {} iters {}'.format(
                        epoch, dataset, iters))
                    args.logger.info(best)

                    if args.early_stop and (
                            iters -
                            best.iters) // args.eval_every > args.patience:
                        args.logger.info("Early stopping.")
                        return

                model.train()

                def get_lr_anneal(iters):
                    lr_end = args.lr_min
                    return max(0, (args.lr - lr_end) *
                               (args.linear_anneal_steps - iters) /
                               args.linear_anneal_steps) + lr_end

                if args.lr_anneal == "linear":
                    opt.param_groups[0]['lr'] = get_lr_anneal(iters)

                opt.zero_grad()

                batch_size = len(train_batch)
                img = extra_input["img"][dataset][0].index_select(
                    dim=0, index=train_batch.idx.cpu())  # (batch_size, D_img)
                en, en_len = train_batch.__dict__["_" +
                                                  str(random.randint(1, 5))]
                R = model(en[:, 1:], en_len - 1, cuda(img))
                R['loss'] = R['loss'].mean()

                total_loss = 0
                for loss_name in loss_names:
                    total_loss += R[loss_name] * loss_cos[loss_name]

                train_metrics.accumulate(
                    batch_size, *[R[name].item() for name in loss_names])

                total_loss.backward()
                if args.plot_grad:
                    plot_grad(writer, model, iters)

                if args.grad_clip > 0:
                    nn.utils.clip_grad_norm_(params, args.grad_clip)

                opt.step()

                if iters % args.eval_every == 0:
                    args.logger.info("update {} : {}".format(
                        iters, str(train_metrics)))

                if iters % args.eval_every == 0 and not args.debug:
                    write_tb(writer, loss_names, [train_metrics.__getattr__(name) for name in loss_names], \
                             iters, prefix="train/")
                    write_tb(writer, ['lr'], [opt.param_groups[0]['lr']],
                             iters,
                             prefix="train/")
                    train_metrics.reset()
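All of these training loops share the same linear learning-rate anneal. Written as a standalone function, with made-up hyperparameters for illustration:

# Standalone form of the linear LR anneal used in every train_model above.
def linear_anneal(lr_start, lr_end, anneal_steps, iters):
    return max(0.0, (lr_start - lr_end) * (anneal_steps - iters) / anneal_steps) + lr_end

# e.g. with lr_start=3e-4, lr_end=1e-5, anneal_steps=100000:
#   linear_anneal(3e-4, 1e-5, 100000, 0)      -> 3e-4
#   linear_anneal(3e-4, 1e-5, 100000, 50000)  -> 1.55e-4
#   linear_anneal(3e-4, 1e-5, 100000, 200000) -> 1e-5  (clamped at lr_end)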
Example #6
    def start(self):
        # Prepare Metrics
        train_metrics = Metrics('train_loss',
                                *list(self.loss_cos.keys()),
                                *self.monitor_names,
                                data_type="avg")
        best = Best(max,
                    'de_bleu',
                    'en_bleu',
                    'iters',
                    model=self.model,
                    opt=self.opt,
                    path=self.args.model_path + self.args.id_str,
                    gpu=self.args.gpu,
                    debug=self.args.debug)
        # Determine when to stop iterlearn
        iters = self.extra_input['resume']['iters'] if self.resume else 0
        self.args.logger.info('Start Training at iters={}'.format(iters))
        try:
            train_it = iter(self.train_it)
            while iters < self.args.max_training_steps:
                train_batch = next(train_it)
                if iters >= self.args.max_training_steps:
                    self.args.logger.info(
                        'stopping training after {} training steps'.format(
                            self.args.max_training_steps))
                    break

                self._maybe_save(iters)

                if iters % self.args.eval_every == 0:
                    self.model.eval()
                    self.evaluate(iters, best)
                    self.supervise_evaluate(iters)
                    if self.args.plot_grad:
                        self._plot_grad(iters)

                self.model.train()
                self.train_step(iters, train_batch, train_metrics)

                if iters % self.args.eval_every == 0:
                    self.args.logger.info("update {} : {}".format(
                        iters, str(train_metrics)))
                    train_stats = {}
                    train_stats.update({
                        name: train_metrics.__getattr__(name)
                        for name in self.loss_cos
                    })
                    train_stats.update({
                        name: train_metrics.__getattr__(name)
                        for name in self.monitor_names
                    })
                    train_stats['lr'] = self.opt.param_groups[0]['lr']
                    write_tb(self.writer, train_stats, iters, prefix="train/")
                    train_metrics.reset()

                iters += 1
        except (InterruptedError, KeyboardInterrupt):
            # End Gracefully
            self.end_gracefully(iters)
        self.writer.flush()
        self.writer.close()
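Example #6 is the only loop here that handles interruption explicitly, via the try/except around the training loop. A self-contained illustration of the same pattern (not the project's end_gracefully), saving state when training is interrupted:

import torch
import torch.nn as nn

# Toy model/optimizer just to make the pattern runnable.
model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

step = 0
try:
    for step in range(100000):
        opt.zero_grad()
        loss = model(torch.randn(8, 4)).pow(2).mean()
        loss.backward()
        opt.step()
except (InterruptedError, KeyboardInterrupt):
    # Mirrors the pattern above: persist state before exiting.
    torch.save({"step": step,
                "model": model.state_dict(),
                "opt": opt.state_dict()}, "interrupt.pt")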