# Shared imports assumed by the snippets below. These snippets come from
# different modules of the same project; project-level helpers (Metrics, Best,
# cuda, computeBLEU, write_tb, plot_grad, print_bleu, ImageFolderWithPaths,
# TextVocab, preprocess_rc, get_retrieve_result) are assumed to be importable
# from the surrounding repository and are not shown here.
import os
import random
from os.path import join
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision


def supervise_evaluate_loop(agent, dev_it, dataset='iwslt', pair='fr_en'):
    dev_metrics = Metrics('s2p_dev', *['nll'])
    with torch.no_grad():
        agent.eval()
        trg_corpus, hyp_corpus = [], []

        for j, dev_batch in enumerate(dev_it):
            if dataset == "iwslt" or dataset == 'iwslt_small':
                src, src_len = dev_batch.src
                trg, trg_len = dev_batch.trg
                trg_field = dev_batch.dataset.fields['trg']
            elif dataset == "multi30k":
                src_lang, trg_lang = pair.split("_")
                src, src_len = dev_batch.__dict__[src_lang]
                trg, trg_len = dev_batch.__dict__[trg_lang]
                trg_field = dev_batch.dataset.fields[trg_lang]
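            # encoder never receives <BOS>, so feed src[:, 1:]; the decoder is fed trg[:, :-1] (teacher forcing)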
            logits, _ = agent(src[:, 1:], src_len - 1, trg[:, :-1])
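            # token-level NLL against the shifted target trg[:, 1:], ignoring pad index 0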
            nll = F.cross_entropy(logits,
                                  trg[:, 1:].contiguous().view(-1),
                                  reduction='mean',
                                  ignore_index=0)
            num_trg = (trg_len - 1).sum().item()
            dev_metrics.accumulate(num_trg, **{'nll': nll})
            hyp = agent.decode(src, src_len, 'greedy', 0)
            trg_corpus.extend(trg_field.reverse(trg, unbpe=True))
            hyp_corpus.extend(trg_field.reverse(hyp, unbpe=True))
        bleu = computeBLEU(hyp_corpus, trg_corpus, corpus=True)
    return dev_metrics, bleu
def get_fr_en_imitate_stats(args, model, dev_it, monitor_names, extra_input):
    """ En BLUE, LM score and img prediction """
    model.eval()
    eval_metrics = Metrics('dev_loss', *monitor_names, data_type="avg")
    eval_metrics.reset()
    with torch.no_grad():
        unbpe = True
        en_corpus = []
        en_hyp = []

        for j, dev_batch in enumerate(dev_it):
            en_corpus.extend(args.EN.reverse(dev_batch.en[0], unbpe=unbpe))
            en_msg, en_msg_len = model.fr_en_speak(dev_batch,
                                                   is_training=False)
            en_hyp.extend(args.EN.reverse(en_msg, unbpe=unbpe))
            results, _ = model.get_grounding(
                en_msg,
                en_msg_len,
                dev_batch,
                en_lm=extra_input["en_lm"],
                all_img=extra_input["img"]['multi30k'][1],
                ranker=extra_input["ranker"])
            # Get entropy
            neg_Hs = model.fr_en.dec.neg_Hs  # (batch_size, en_msg_len)
            neg_Hs = neg_Hs.mean()  # (1,)
            results["neg_Hs"] = neg_Hs
            if len(monitor_names) > 0:
                eval_metrics.accumulate(len(dev_batch), **results)
            if args.debug:
                break

        bleu_en = computeBLEU(en_hyp, en_corpus, corpus=True)
        stats = eval_metrics.__dict__['metrics']
        stats['bleu_en'] = bleu_en[0]
        return stats
def valid_model(args, model, valid_img_feats, valid_caps, valid_lens):
    model.eval()
    batch_size = 32
    start = 0
    val_metrics = Metrics('val_loss', 'loss', data_type="avg")
    with torch.no_grad():
        while start < valid_img_feats.shape[0]:
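            # each image has 5 reference captions; pick one at random for this batch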
            cap_id = random.randint(0, 4)
            end = start + batch_size
            batch_img_feat = cuda(valid_img_feats[start: end])
            batch_ens = cuda(valid_caps[cap_id][start: end])
            batch_lens = cuda(valid_lens[cap_id][start: end])
            R = model(batch_ens[:, 1:], batch_lens - 1, batch_img_feat)
            if args.img_pred_loss == "vse":
                R['loss'] = R['loss'].sum()
            elif args.img_pred_loss == "mse":
                R['loss'] = R['loss'].mean()
            else:
                raise ValueError
            val_metrics.accumulate(batch_img_feat.size(0), R['loss'].item())
            start = end
        return val_metrics
    def evaluate_communication(self):
        """ Use greedy decoding and check scores like BLEU, language model and grounding """
        eval_metrics = Metrics('dev_loss',
                               *self.monitor_names,
                               data_type="avg")
        eval_metrics.reset()
        with torch.no_grad():
            unbpe = True
            self.model.eval()
            en_corpus, de_corpus = [], []
            en_hyp, de_hyp = [], []

            for j, dev_batch in enumerate(self.dev_it):
                en_corpus.extend(
                    self.args.EN.reverse(dev_batch.en[0], unbpe=unbpe))
                de_corpus.extend(
                    self.args.DE.reverse(dev_batch.de[0], unbpe=unbpe))

                en_msg, de_msg, en_msg_len, _ = self.model.decode(dev_batch)
                en_hyp.extend(self.args.EN.reverse(en_msg, unbpe=unbpe))
                de_hyp.extend(self.args.DE.reverse(de_msg, unbpe=unbpe))
                results, _ = self.model.get_grounding(
                    en_msg,
                    en_msg_len,
                    dev_batch,
                    en_lm=self.extra_input["en_lm"],
                    all_img=self.extra_input["img"]['multi30k'][1],
                    ranker=self.extra_input["ranker"])
                # Get entropy
                neg_Hs = self.model.fr_en.dec.neg_Hs  # (batch_size, en_msg_len)
                neg_Hs = neg_Hs.mean()  # (1,)
                results["neg_Hs"] = neg_Hs
                if len(self.monitor_names) > 0:
                    eval_metrics.accumulate(len(dev_batch), **results)

            bleu_en = computeBLEU(en_hyp, en_corpus, corpus=True)
            bleu_de = computeBLEU(de_hyp, de_corpus, corpus=True)
            self.args.logger.info(eval_metrics)
            self.args.logger.info("Fr-En {} : {}".format(
                'valid', print_bleu(bleu_en)))
            self.args.logger.info("En-De {} : {}".format(
                'valid', print_bleu(bleu_de)))
            return eval_metrics, bleu_en, bleu_de, en_corpus, en_hyp, de_hyp
def train_model(args, model, iterators):
    (train_it, dev_it) = iterators

    if not args.debug:
        decoding_path = Path(join(args.decoding_path, args.id_str))
        decoding_path.mkdir(parents=True, exist_ok=True)
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(join(args.event_path, args.id_str))

    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9, lr=args.lr)
    else:
        raise NotImplementedError

    extra_loss_names = []
    train_metrics = Metrics('train_loss',
                            'nll',
                            *extra_loss_names,
                            data_type="avg")
    dev_metrics = Metrics('dev_loss',
                          'nll',
                          *extra_loss_names,
                          data_type="avg")
    best = Best(max,
                'dev_bleu',
                'iters',
                model=model,
                opt=opt,
                path=join(args.model_path, args.id_str),
                gpu=args.gpu,
                debug=args.debug)

    for iters, train_batch in enumerate(train_it):
        if iters >= args.max_training_steps:
            args.logger.info(
                'stopping training after {} training steps'.format(
                    args.max_training_steps))
            break

        if not args.debug and hasattr(
                args, 'save_every') and iters % args.save_every == 0:
            args.logger.info(
                'save (back-up) checkpoints at iters={}'.format(iters))
            with torch.cuda.device(args.gpu):
                torch.save(
                    model.state_dict(),
                    '{}_iter={}.pt'.format(args.model_path + args.id_str,
                                           iters))
                torch.save([iters, opt.state_dict()],
                           '{}_iter={}.pt.states'.format(
                               args.model_path + args.id_str, iters))

        if iters % args.eval_every == 0:
            dev_metrics.reset()
            dev_bleu = valid_model(args, model, dev_it, dev_metrics, 'argmax')
            if not args.debug:
                write_tb(writer, ['nll'], [dev_metrics.nll],
                         iters,
                         prefix="dev/")
                write_tb(writer, [
                    'bleu', *("p_1 p_2 p_3 p_4".split()), 'bp', 'len_ref',
                    'len_hyp'
                ],
                         dev_bleu,
                         iters,
                         prefix="bleu/")
            best.accumulate(dev_bleu[0], iters)
            args.logger.info(best)
            """
            if args.early_stop and (iters - best.iters) // args.eval_every > args.patience:
                args.logger.info("Early stopping.")
                break
            """

        model.train()

        def get_lr_anneal(iters):
            lr_end = 1e-5
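            # linearly decay the learning rate from args.lr to lr_end over linear_anneal_steps, then hold at lr_end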
            return max(0, (args.lr - lr_end) *
                       (args.linear_anneal_steps - iters) /
                       args.linear_anneal_steps) + lr_end

        if args.lr_anneal == "linear":
            opt.param_groups[0]['lr'] = get_lr_anneal(iters)

        opt.zero_grad()

        batch_size = len(train_batch)
        if args.dataset == "iwslt" or args.dataset == 'iwslt_small':
            src, src_len = train_batch.src
            trg, trg_len = train_batch.trg
        elif args.dataset == "multi30k":
            src_lang, trg_lang = args.pair.split("_")
            src, src_len = train_batch.__dict__[src_lang]
            trg, trg_len = train_batch.__dict__[trg_lang]
        else:
            raise ValueError

        # NOTE encoder never receives <BOS> token
        # because during communication, Agent A's decoder will never output <BOS>
        logits, _ = model(src[:, 1:], src_len - 1, trg[:, :-1])
        nll = F.cross_entropy(logits,
                              trg[:, 1:].contiguous().view(-1),
                              reduction='mean',
                              ignore_index=0)
        num_trg = (trg_len - 1).sum().item()
        train_metrics.accumulate(num_trg, nll.item())

        nll.backward()
        # clip after backward so the freshly computed gradients are actually clipped
        if args.grad_clip > 0:
            total_norm = nn.utils.clip_grad_norm_(params, args.grad_clip)
        opt.step()

        if iters % args.print_every == 0:
            args.logger.info("update {} : {}".format(iters,
                                                     str(train_metrics)))
            if not args.debug:
                write_tb(writer, ['nll', 'lr'],
                         [train_metrics.nll, opt.param_groups[0]['lr']],
                         iters,
                         prefix="train/")
            train_metrics.reset()
def train_model(args, model):

    resnet = torchvision.models.resnet152(pretrained=True)
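    # drop the final classification layer so the ResNet outputs pooled image features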
    resnet = nn.Sequential(*list(resnet.children())[:-1])
    resnet = nn.DataParallel(resnet).cuda()
    resnet.eval()

    if not args.debug:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(args.event_path + args.id_str)

    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9, lr=args.lr)
    else:
        raise NotImplementedError

    loss_names, loss_cos = ["loss"], {"loss":1.0}
    monitor_names = "cap_r1 cap_r5 cap_r10 img_r1 img_r5 img_r10".split()

    train_metrics = Metrics('train_loss', *loss_names, data_type="avg")
    best = Best(max, 'r1', 'iters', model=model, opt=opt,
                path=args.model_path + args.id_str,
                gpu=args.gpu, debug=args.debug)

    # Train dataset
    args.logger.info("Loading train imgs...")
    train_dataset = ImageFolderWithPaths(os.path.join(args.data_dir, 'flickr30k'), preprocess_rc)
    with open(os.path.join(args.data_dir, 'flickr30k/train.txt'), 'r') as f:
        train_imgs = f.readlines()
    train_imgs = [x.strip() for x in train_imgs if x.strip() != ""]
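    # keep only the images listed in train.txt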
    train_dataset.samples = [x for x in train_dataset.samples if x[0].split("/")[-1] in train_imgs]
    train_dataset.imgs = [x for x in train_dataset.imgs if x[0].split("/")[-1] in train_imgs]
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8,
                                               pin_memory=False)
    args.logger.info("Train loader built!")

    en_vocab = TextVocab(counter=torch.load(os.path.join(args.data_dir, 'bpe/vocab.en.pth')))
    word2idx = en_vocab.stoi
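    # load the 5 BPE'd caption files, wrap each sentence in <bos>/<eos>, and map tokens to vocab ids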
    train_en = []
    for idx in range(5):
        with open(os.path.join(args.data_dir, 'flickr30k/caps', 'train.{}.bpe'.format(idx + 1))) as f:
            doc = [["<bos>"] + line.strip().split() + ["<eos>"] for line in f if line.strip() != ""]
        train_en.append([[word2idx[word] for word in sentence] for sentence in doc])

    args.logger.info("Train corpus built!")

    # Valid dataset
    valid_img_feats = torch.tensor(torch.load(os.path.join(args.data_dir, 'flickr30k/val_feat.pth')))
    valid_ens = []
    valid_en_lens = []
    for idx in range(5):
        valid_en = []
        with open(os.path.join(args.data_dir, 'flickr30k/caps', 'val.{}.bpe'.format(idx+1))) as f:
            for line in f:
                line = line.strip()
                if line == "":
                    continue

                words = ["<bos>"] + line.split() + ["<eos>"]
                words = [word2idx[word] for word in words]
                valid_en.append(words)

        # Pad
        valid_en_len = [len(sent) for sent in valid_en]
        valid_en = [np.pad(xx, (0, max(valid_en_len) - len(xx)), 'constant', constant_values=0)
                    for xx in valid_en]
        valid_ens.append(torch.tensor(valid_en).long())
        valid_en_lens.append(torch.tensor(valid_en_len).long())
    args.logger.info("Valid corpus built!")

    iters = -1
    should_stop = False
    for epoch in range(999999999):
        if should_stop:
            break

        for idx, (train_img, lab, path) in enumerate(train_loader):
            iters += 1
            if iters > args.max_training_steps:
                should_stop = True
                break

            if iters % args.eval_every == 0:
                res = get_retrieve_result(args, model, valid_caps=valid_ens, valid_lens=valid_en_lens,
                                          valid_img_feats=valid_img_feats)
                val_metrics = valid_model(args, model, valid_img_feats=valid_img_feats,
                                          valid_caps=valid_ens, valid_lens=valid_en_lens)
                args.logger.info("[VALID] update {} : {}".format(iters, str(val_metrics)))
                if not args.debug:
                    write_tb(writer, monitor_names, res, iters, prefix="dev/")
                    write_tb(writer, loss_names, [val_metrics.__getattr__(name) for name in loss_names],
                             iters, prefix="dev/")

                best.accumulate((res[0]+res[3])/2, iters)
                args.logger.info('model:' + args.prefix + args.hp_str)
                args.logger.info('epoch {} iters {}'.format(epoch, iters))
                args.logger.info(best)

                if args.early_stop and (iters - best.iters) // args.eval_every > args.patience:
                    args.logger.info("Early stopping.")
                    return

            model.train()

            def get_lr_anneal(iters):
                lr_end = args.lr_min
                return max(0, (args.lr - lr_end) * (args.linear_anneal_steps - iters) /
                           args.linear_anneal_steps) + lr_end

            if args.lr_anneal == "linear":
                opt.param_groups[0]['lr'] = get_lr_anneal(iters)

            opt.zero_grad()
            batch_size = len(path)
            path = [p.split("/")[-1] for p in path]
            sentence_idx = [train_imgs.index(p) for p in path]
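            # each image has 5 captions; sample one at random for every example in the batch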
            en = [train_en[random.randint(0, 4)][sentence_i] for sentence_i in sentence_idx]
            en_len = [len(x) for x in en]

            en = [np.pad(xx, (0, max(en_len) - len(xx)), 'constant', constant_values=0) for xx in en]
            en = cuda(torch.LongTensor(np.array(en).tolist()))
            en_len = cuda(torch.LongTensor(en_len))

            with torch.no_grad():
                train_img = resnet(train_img).view(batch_size, -1)
            R = model(en[:, 1:], en_len - 1, train_img)
            if args.img_pred_loss == "vse":
                R['loss'] = R['loss'].sum()
            elif args.img_pred_loss == "mse":
                R['loss'] = R['loss'].mean()
            else:
                raise ValueError(args.img_pred_loss)

            total_loss = 0
            for loss_name in loss_names:
                total_loss += R[loss_name] * loss_cos[loss_name]

            train_metrics.accumulate(batch_size, *[R[name].item() for name in loss_names])

            total_loss.backward()
            if args.plot_grad:
                plot_grad(writer, model, iters)

            if args.grad_clip > 0:
                nn.utils.clip_grad_norm_(params, args.grad_clip)

            opt.step()

            if iters % args.eval_every == 0:
                args.logger.info("update {} : {}".format(iters, str(train_metrics)))

            if iters % args.eval_every == 0 and not args.debug:
                write_tb(writer, loss_names,
                         [train_metrics.__getattr__(name) for name in loss_names],
                         iters, prefix="train/")
                write_tb(writer, ['lr'], [opt.param_groups[0]['lr']], iters, prefix="train/")
                train_metrics.reset()
def train_model(args, model, iterators, extra_input):
    (train_its, dev_its) = iterators

    if not args.debug:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(args.event_path + args.id_str)

    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9, lr=args.lr)
    else:
        raise NotImplementedError

    loss_names, loss_cos = ["nll"], {"nll": 1.0}
    monitor_names = ["nll_rnd"]
    """
    if args.rep_pen_co > 0.0:
        loss_names.append("nll_cur")
        loss_cos["nll_cur"] = -1 * args.rep_pen_co
    else:
        monitor_names.append("nll_cur")
    """

    train_metrics = Metrics('train_loss', *loss_names, data_type="avg")
    dev_metrics = Metrics('dev_loss',
                          *loss_names,
                          *monitor_names,
                          data_type="avg")
    best = Best(min, 'loss', 'iters', model=model, opt=opt,
                path=args.model_path + args.id_str,
                gpu=args.gpu, debug=args.debug)

    iters = 0
    should_stop = False
    for epoch in range(999999999):
        if should_stop:
            break

        for dataset in args.dataset.split("_"):
            if should_stop:
                break

            train_it = train_its[dataset]
            for _, train_batch in enumerate(train_it):
                if iters >= args.max_training_steps:
                    args.logger.info(
                        'stopping training after {} training steps'.format(
                            args.max_training_steps))
                    should_stop = True
                    break

                if iters % args.eval_every == 0:
                    dev_metrics.reset()
                    valid_model(args, model, dev_its['multi30k'], dev_metrics,
                                iters, loss_names, monitor_names, extra_input)
                    if not args.debug:
                        write_tb(writer, loss_names,
                                 [dev_metrics.__getattr__(name) for name in loss_names],
                                 iters, prefix="dev/")
                        write_tb(writer, monitor_names,
                                 [dev_metrics.__getattr__(name) for name in monitor_names],
                                 iters, prefix="dev/")
                    best.accumulate(dev_metrics.nll, iters)

                    args.logger.info('model:' + args.prefix + args.hp_str)
                    args.logger.info('epoch {} dataset {} iters {}'.format(
                        epoch, dataset, iters))
                    args.logger.info(best)

                    if args.early_stop and (
                            iters -
                            best.iters) // args.eval_every > args.patience:
                        args.logger.info("Early stopping.")
                        return

                model.train()

                def get_lr_anneal(iters):
                    lr_end = args.lr_min
                    return max(0, (args.lr - lr_end) *
                               (args.linear_anneal_steps - iters) /
                               args.linear_anneal_steps) + lr_end

                if args.lr_anneal == "linear":
                    opt.param_groups[0]['lr'] = get_lr_anneal(iters)

                opt.zero_grad()

                batch_size = len(train_batch)
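                # look up precomputed image features for this batch unless images are disabled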
                img_input = None if args.no_img else cuda(
                    extra_input["img"][dataset][0].index_select(
                        dim=0, index=train_batch.idx.cpu()))
                if dataset == "coco":
                    en, en_len = train_batch.__dict__[
                        "_" + str(random.randint(1, 5))]
                elif dataset == "multi30k":
                    en, en_len = train_batch.en

                decoded = model(en, img_input)
                R = {}
                R["nll"] = F.cross_entropy(decoded,
                                           en[:, 1:].contiguous().view(-1),
                                           ignore_index=0)
                #R["nll_cur"] = F.cross_entropy( decoded, en[:,:-1].contiguous().view(-1), ignore_index=0 )

                total_loss = 0
                for loss_name in loss_names:
                    total_loss += R[loss_name] * loss_cos[loss_name]

                train_metrics.accumulate(
                    batch_size, *[R[name].item() for name in loss_names])

                total_loss.backward()
                if args.plot_grad:
                    plot_grad(writer, model, iters)

                if args.grad_clip > 0:
                    nn.utils.clip_grad_norm_(params, args.grad_clip)

                opt.step()
                iters += 1

                if iters % args.eval_every == 0:
                    args.logger.info("update {} : {}".format(
                        iters, str(train_metrics)))

                if iters % args.eval_every == 0 and not args.debug:
                    write_tb(writer, loss_names,
                             [train_metrics.__getattr__(name) for name in loss_names],
                             iters, prefix="train/")
                    write_tb(writer, ['lr'], [opt.param_groups[0]['lr']],
                             iters,
                             prefix="train/")
                    train_metrics.reset()
def train_model(args, model, iterators, extra_input):
    (train_its, dev_its) = iterators

    if not args.debug:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(args.event_path + args.id_str)

    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9, lr=args.lr)
    else:
        raise NotImplementedError

    loss_names, loss_cos = ["loss"], {"loss": 1.0}
    monitor_names = "cap_r1 cap_r5 cap_r10 img_r1 img_r5 img_r10".split()

    train_metrics = Metrics('train_loss', *loss_names, data_type="avg")
    best = Best(max, 'r1', 'iters', model=model, opt=opt,
                path=args.model_path + args.id_str,
                gpu=args.gpu, debug=args.debug)

    iters = 0
    for epoch in range(999999999):
        for dataset in args.dataset.split("_"):
            train_it = train_its[dataset]
            for _, train_batch in enumerate(train_it):
                iters += 1

                if iters % args.eval_every == 0:
                    R = valid_model(args, model, dev_its['multi30k'],
                                    extra_input)
                    if not args.debug:
                        write_tb(writer,
                                 monitor_names,
                                 R,
                                 iters,
                                 prefix="dev/")
                    best.accumulate((R[0] + R[3]) / 2, iters)

                    args.logger.info('model:' + args.prefix + args.hp_str)
                    args.logger.info('epoch {} dataset {} iters {}'.format(
                        epoch, dataset, iters))
                    args.logger.info(best)

                    if args.early_stop and (
                            iters -
                            best.iters) // args.eval_every > args.patience:
                        args.logger.info("Early stopping.")
                        return

                model.train()

                def get_lr_anneal(iters):
                    lr_end = args.lr_min
                    return max(0, (args.lr - lr_end) *
                               (args.linear_anneal_steps - iters) /
                               args.linear_anneal_steps) + lr_end

                if args.lr_anneal == "linear":
                    opt.param_groups[0]['lr'] = get_lr_anneal(iters)

                opt.zero_grad()

                batch_size = len(train_batch)
                img = extra_input["img"][dataset][0].index_select(
                    dim=0, index=train_batch.idx.cpu())  # (batch_size, D_img)
                en, en_len = train_batch.__dict__["_" +
                                                  str(random.randint(1, 5))]
                R = model(en[:, 1:], en_len - 1, cuda(img))
                R['loss'] = R['loss'].mean()

                total_loss = 0
                for loss_name in loss_names:
                    total_loss += R[loss_name] * loss_cos[loss_name]

                train_metrics.accumulate(
                    batch_size, *[R[name].item() for name in loss_names])

                total_loss.backward()
                if args.plot_grad:
                    plot_grad(writer, model, iters)

                if args.grad_clip > 0:
                    nn.utils.clip_grad_norm_(params, args.grad_clip)

                opt.step()

                if iters % args.eval_every == 0:
                    args.logger.info("update {} : {}".format(
                        iters, str(train_metrics)))

                if iters % args.eval_every == 0 and not args.debug:
                    write_tb(writer, loss_names,
                             [train_metrics.__getattr__(name) for name in loss_names],
                             iters, prefix="train/")
                    write_tb(writer, ['lr'], [opt.param_groups[0]['lr']],
                             iters,
                             prefix="train/")
                    train_metrics.reset()
    def start(self):
        # Prepare Metrics
        train_metrics = Metrics('train_loss',
                                *list(self.loss_cos.keys()),
                                *self.monitor_names,
                                data_type="avg")
        best = Best(max,
                    'de_bleu',
                    'en_bleu',
                    'iters',
                    model=self.model,
                    opt=self.opt,
                    path=self.args.model_path + self.args.id_str,
                    gpu=self.args.gpu,
                    debug=self.args.debug)
        # Determine when to stop iterlearn
        iters = self.extra_input['resume']['iters'] if self.resume else 0
        self.args.logger.info('Start Training at iters={}'.format(iters))
        try:
            train_it = iter(self.train_it)
            while iters < self.args.max_training_steps:
                train_batch = next(train_it)
                if iters >= self.args.max_training_steps:
                    self.args.logger.info(
                        'stopping training after {} training steps'.format(
                            self.args.max_training_steps))
                    break

                self._maybe_save(iters)

                if iters % self.args.eval_every == 0:
                    self.model.eval()
                    self.evaluate(iters, best)
                    self.supervise_evaluate(iters)
                    if self.args.plot_grad:
                        self._plot_grad(iters)

                self.model.train()
                self.train_step(iters, train_batch, train_metrics)

                if iters % self.args.eval_every == 0:
                    self.args.logger.info("update {} : {}".format(
                        iters, str(train_metrics)))
                    train_stats = {}
                    train_stats.update({
                        name: train_metrics.__getattr__(name)
                        for name in self.loss_cos
                    })
                    train_stats.update({
                        name: train_metrics.__getattr__(name)
                        for name in self.monitor_names
                    })
                    train_stats['lr'] = self.opt.param_groups[0]['lr']
                    write_tb(self.writer, train_stats, iters, prefix="train/")
                    train_metrics.reset()

                iters += 1
        except (InterruptedError, KeyboardInterrupt):
            # End Gracefully
            self.end_gracefully(iters)
        self.writer.flush()
        self.writer.close()