Esempio n. 1
0
def main():
    """Train a molecular autoencoder (VAE or plain AE) as configured by cmd_args.

    Seeds all RNGs, builds the requested model, optionally warm-starts from a
    saved checkpoint, then trains for cmd_args.num_epochs epochs, saving a
    per-epoch checkpoint and tracking the best validation loss.
    """
    seed = 19260817

    print(cmd_args)

    # Seed every RNG source so runs are reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Select the autoencoder variant; anything else is a configuration error.
    if cmd_args.ae_type == 'autoenc':
        ae = MolAutoEncoder()
    elif cmd_args.ae_type == 'vae':
        ae = MolVAE()
    else:
        raise Exception('unknown ae type %s' % cmd_args.ae_type)

    if cmd_args.mode == 'gpu':
        ae = ae.cuda()

    # Optionally warm-start from a previously saved state dict.
    saved = cmd_args.saved_model
    if saved is not None and saved != '' and os.path.isfile(saved):
        print('loading model from %s' % saved)
        ae.load_state_dict(torch.load(saved))

    assert cmd_args.encoder_type == 'cnn'

    optimizer = optim.Adam(ae.parameters(), lr=cmd_args.learning_rate)
    # Halve the LR after 3 stagnant validation epochs, down to 1e-4.
    lr_scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=3,
                                     verbose=True, min_lr=0.0001)

    train_binary, train_masks, valid_binary, valid_masks = load_data()
    print('num_train: %d\tnum_valid: %d' % (train_binary.shape[0], valid_binary.shape[0]))

    sample_idxes = list(range(train_binary.shape[0]))
    best_valid_loss = None
    for epoch in range(cmd_args.num_epochs):
        # Reshuffle the training order each epoch.
        random.shuffle(sample_idxes)

        avg_loss = loop_dataset('train', ae, sample_idxes, train_binary, train_masks, optimizer)
        print('>>>>average \033[92mtraining\033[0m of epoch %d: loss %.5f perp %.5f kl %.5f' % (epoch, avg_loss[0], avg_loss[1], avg_loss[2]))

        if epoch % 1 == 0:  # kept from the original: validates every epoch
            valid_loss = loop_dataset('valid', ae, list(range(valid_binary.shape[0])), valid_binary, valid_masks)
            print('        average \033[93mvalid\033[0m of epoch %d: loss %.5f perp %.5f kl %.5f' % (epoch, valid_loss[0], valid_loss[1], valid_loss[2]))
            valid_loss = valid_loss[0]
            lr_scheduler.step(valid_loss)
            torch.save(ae.state_dict(), cmd_args.save_dir + '/epoch-%d.model' % epoch)
            if best_valid_loss is None or valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                print('----saving to best model since this is the best valid loss so far.----')
                torch.save(ae.state_dict(), cmd_args.save_dir + '/epoch-best.model')
 def build_model(self):
     """Construct the ResNet-101 model plus its loss, optimizer and scheduler on the GPU."""
     print ('==> Build model and setup loss and optimizer')
     # Network: pretrained ResNet-101 taking 3-channel input, moved to the GPU.
     net = resnet101(pretrained=True, channel=3)
     self.model = net.cuda()
     # Cross-entropy objective, also on the GPU.
     self.criterion = nn.CrossEntropyLoss().cuda()
     # Plain SGD with momentum; LR drops are delegated to the plateau scheduler.
     sgd = torch.optim.SGD(self.model.parameters(), self.lr, momentum=0.9)
     self.optimizer = sgd
     self.scheduler = ReduceLROnPlateau(self.optimizer, 'min', patience=1, verbose=True)
Esempio n. 3
0
 def test_reduce_lr_on_plateau_state_dict(self):
     """Round-trip a ReduceLROnPlateau state dict into a differently-configured copy."""
     scheduler = ReduceLROnPlateau(self.opt, mode='min', factor=0.1, patience=2)
     scores = [1.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 3.0, 2.0, 1.0]
     for value in scores:
         scheduler.step(value)
     # Deliberately different hyper-parameters; load_state_dict must overwrite them.
     scheduler_copy = ReduceLROnPlateau(self.opt, mode='max', factor=0.5, patience=10)
     scheduler_copy.load_state_dict(scheduler.state_dict())
     # 'optimizer' and the 'is_better' closure are excluded from state-dict equality.
     for key, value in scheduler.__dict__.items():
         if key in {'optimizer', 'is_better'}:
             continue
         self.assertEqual(value, scheduler_copy.__dict__[key], allow_inf=True)
Esempio n. 4
0
 def setup_training(self, optimizer, lr, min_lr, momentum, cooldown):
     """Create the optimizer ('SGD', 'Adadelta' or 'Adam') and a plateau LR scheduler.

     Raises ValueError for any other optimizer name.
     """
     assert self.model is not None
     params = self.model.parameters()
     if optimizer == 'SGD':
         self.optimizer = torch.optim.SGD(params, lr, momentum=momentum)
     elif optimizer == 'Adadelta':
         self.optimizer = torch.optim.Adadelta(params, lr)
     elif optimizer == 'Adam':
         self.optimizer = torch.optim.Adam(params, lr)
     else:
         raise ValueError
     self.min_lr = min_lr
     # Halve the LR as soon as the monitored metric stops improving
     # (patience=0), never going below min_lr.
     self.scheduler = ReduceLROnPlateau(self.optimizer, 'min', patience=0,
                                        cooldown=cooldown, factor=0.5,
                                        min_lr=min_lr)
     self.setup_evalutator()
    def __define_optimizer(self, learning_rate, weight_decay,
                           lr_drop_factor, lr_drop_patience, optimizer='Adam'):
        """Create self.optimizer and self.lr_scheduler for the current model.

        Args:
            learning_rate: initial learning rate for the optimizer.
            weight_decay: L2 regularization strength.
            lr_drop_factor: multiplicative LR drop applied on plateau.
            lr_drop_patience: epochs without improvement before dropping the LR.
            optimizer: one of 'RMSprop', 'Adam', 'Adadelta', 'SGD'.
        """
        assert optimizer in ['RMSprop', 'Adam', 'Adadelta', 'SGD']

        # Only optimize parameters that still require gradients (frozen layers
        # are skipped).  Built-in `filter` replaces `itertools.ifilter`, which
        # exists only on Python 2 and raises NameError on Python 3.
        parameters = filter(lambda p: p.requires_grad,
                            self.model.parameters())

        if optimizer == 'RMSprop':
            self.optimizer = optim.RMSprop(
                parameters, lr=learning_rate, weight_decay=weight_decay)
        elif optimizer == 'Adadelta':
            self.optimizer = optim.Adadelta(
                parameters, lr=learning_rate, weight_decay=weight_decay)
        elif optimizer == 'Adam':
            self.optimizer = optim.Adam(
                parameters, lr=learning_rate, weight_decay=weight_decay)
        elif optimizer == 'SGD':
            self.optimizer = optim.SGD(
                parameters, lr=learning_rate, momentum=0.9,
                weight_decay=weight_decay)

        # Drop the LR by lr_drop_factor once the monitored quantity plateaus.
        self.lr_scheduler = ReduceLROnPlateau(
            self.optimizer, mode='min', factor=lr_drop_factor,
            patience=lr_drop_patience, verbose=True)
def main():
    """Train FCN8s on Cityscapes fine annotations, resuming from a snapshot if one is configured."""
    net = FCN8s(num_classes=cityscapes.num_classes).cuda()

    if len(args['snapshot']) == 0:
        # Fresh run: start at epoch 1 with a sentinel "worst possible" best_record.
        curr_epoch = 1
        args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        # Snapshot filenames encode metrics as underscore-separated fields; the
        # odd indices below hold the values (epoch at [1], val_loss at [3], ...)
        # NOTE(review): assumes the filename layout written by the save side —
        # confirm against the checkpoint writer.
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {'epoch': int(split_snapshot[1]), 'val_loss': float(split_snapshot[3]),
                               'acc': float(split_snapshot[5]), 'acc_cls': float(split_snapshot[7]),
                               'mean_iu': float(split_snapshot[9]), 'fwavacc': float(split_snapshot[11])}
    net.train()

    # ImageNet channel statistics used for input normalization / de-normalization.
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # Scale the short side up so a crop of input_size covers 87.5% of it.
    short_size = int(min(args['input_size']) / 0.875)
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.RandomCrop(args['input_size']),
        joint_transforms.RandomHorizontallyFlip()
    ])
    val_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.CenterCrop(args['input_size'])
    ])
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    # Inverse of input_transform, used when visualizing predictions.
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])
    visualize = standard_transforms.ToTensor()

    train_set = cityscapes.CityScapes('fine', 'train', joint_transform=train_joint_transform,
                                      transform=input_transform, target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=8, shuffle=True)
    val_set = cityscapes.CityScapes('fine', 'val', joint_transform=val_joint_transform, transform=input_transform,
                                    target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=args['val_batch_size'], num_workers=8, shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=False, ignore_index=cityscapes.ignore_label).cuda()

    # Two parameter groups: biases get 2x the base LR and no weight decay.
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': args['lr'], 'weight_decay': args['weight_decay']}
    ], momentum=args['momentum'])

    if len(args['snapshot']) > 0:
        # Restore optimizer state, then re-apply the configured LRs (group 0 is
        # the bias group created above).
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # NOTE(review): file handle from open() is never closed explicitly.
    open(os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt'), 'w').write(str(args) + '\n\n')

    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=args['lr_patience'], min_lr=1e-10)
    for epoch in range(curr_epoch, args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, args)
        val_loss = validate(val_loader, net, criterion, optimizer, epoch, args, restore_transform, visualize)
        # Step on validation loss so the LR drops when progress stalls.
        scheduler.step(val_loss)
# Training configuration for the torque-LSTM experiments.
# NOTE(review): is_rnn, device, lr, train_path and SKIP are defined elsewhere
# in the original file (not visible here).
batch_size = 128
epochs = 400
use_previous_model = False
epoch_to_use = 10
in_joints = [0, 1, 2, 3, 4, 5]  # joint indices fed to the network
f = True  # forwarded below as the dataset's filter_signal flag
window = 1000  # samples per input window
print('Running for is_rnn value: ', is_rnn)

# One model per setting; the trailing 's' in the name suggests seconds — TODO confirm.
for num in ['60', '120', '180', '240', '300']:
    model = 'filtered_torque_' + num + 's'
    n = int(num)

    network = torqueLstmNetwork(batch_size, device).to(device)
    optimizer = torch.optim.Adam(network.parameters(), lr)
    # All-default plateau scheduler (mode='min', factor=0.1, patience=10).
    scheduler = ReduceLROnPlateau(optimizer, verbose=False)

    train_dataset = indirectDataset(train_path,
                                    window,
                                    SKIP,
                                    in_joints,
                                    num=n,
                                    is_rnn=is_rnn,
                                    filter_signal=f)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              shuffle=True)

    loss_fn = torch.nn.MSELoss()

    start = time.time()
Esempio n. 8
0
def train(mode='train',
          train_path='train.conllx',
          model='dozat',
          dev_path='dev.conllx',
          test_path='test.conllx',
          ud=True,
          output_dir='output',
          emb_dim=0,
          char_emb_dim=0,
          char_model=None,
          tagger=None,
          batch_size=5000,
          n_iters=10,
          dropout_p=0.33,
          num_layers=1,
          print_every=1,
          eval_every=100,
          bi=True,
          lr=0.001,
          adam_beta1=0.9,
          adam_beta2=0.999,
          weight_decay=0.,
          plateau=False,
          resume=False,
          lr_decay=1.0,
          lr_decay_steps=5000,
          clip=5.,
          momentum=0,
          optimizer='adam',
          glove=True,
          seed=42,
          dim=0,
          window_size=0,
          num_filters=0,
          **kwargs):
    """Train a POS tagger on CoNLL-X data and checkpoint the best model.

    Builds (or reloads) word/POS/char vocabularies, constructs a Tagger,
    trains for n_iters iterations, periodically evaluating on dev/test via
    the official scorer, and saves a checkpoint after every evaluation.

    Fixes over the original version:
      * the checkpoint now includes 'best_pos_acc', which the resume path
        reads (previously a KeyError on resume);
      * the computed is_best flag is passed to save_checkpoint instead of a
        hard-coded False.
    """
    device = torch.device(type='cuda') if use_cuda else torch.device(
        type='cpu')

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Snapshot all keyword arguments (plus device) to forward into the model.
    cfg = locals().copy()

    torch.manual_seed(seed)
    np.random.seed(seed)

    # load data component
    dataset_obj = ConllXDataset
    fields = get_data_fields()
    _form = fields['form'][-1]
    _pos = fields['pos'][-1]
    _chars = fields['chars'][-1]

    train_dataset = dataset_obj(train_path, fields)
    dev_dataset = dataset_obj(dev_path, fields)
    test_dataset = dataset_obj(test_path, fields)

    logger.info("Loaded %d train examples" % len(train_dataset))
    logger.info("Loaded %d dev examples" % len(dev_dataset))
    logger.info("Loaded %d test examples" % len(test_dataset))

    form_vocab_path = os.path.join(output_dir, 'vocab.form.pth.tar')
    pos_vocab_path = os.path.join(output_dir, 'vocab.pos.pth.tar')
    char_vocab_path = os.path.join(output_dir, 'vocab.char.pth.tar')

    if not resume:
        # build vocabularies
        # words have a min frequency of 2 to be included; others become <unk>
        # words without a Glove vector are initialized ~ N(0, 0.5) mimicking Glove

        # Note: this requires the latest torchtext development version from Github.
        # - git clone https://github.com/pytorch/text.git torchtext
        # - cd torchtext
        # - python setup.py build
        # - python setup.py install

        def unk_init(x):
            # return 0.01 * torch.randn(x)
            return torch.zeros(x)

        if glove:
            logger.info("Using Glove vectors")
            glove_vectors = GloVe(name='6B', dim=100)
            _form.build_vocab(train_dataset,
                              min_freq=2,
                              unk_init=unk_init,
                              vectors=glove_vectors)
            n_unks = 0
            unk_set = set()
            # for now, set UNK words manually
            # (torchtext does not seem to support it yet)
            for i, token in enumerate(_form.vocab.itos):
                if token not in glove_vectors.stoi:
                    n_unks += 1
                    unk_set.add(token)
                    _form.vocab.vectors[i] = unk_init(emb_dim)
            # print(n_unks, unk_set)

        else:
            _form.build_vocab(train_dataset, min_freq=2)

        _pos.build_vocab(train_dataset)
        _chars.build_vocab(train_dataset)

        # save vocabularies
        torch.save(_form.vocab, form_vocab_path)
        torch.save(_pos.vocab, pos_vocab_path)
        torch.save(_chars.vocab, char_vocab_path)

    else:
        # load vocabularies
        _form.vocab = torch.load(form_vocab_path)
        _pos.vocab = torch.load(pos_vocab_path)
        _chars.vocab = torch.load(char_vocab_path)

    print("First 10 vocabulary entries, words: ",
          " ".join(_form.vocab.itos[:10]))
    print("First 10 vocabulary entries, pos tags: ",
          " ".join(_pos.vocab.itos[:10]))
    print("First 10 vocabulary entries, chars: ",
          " ".join(_chars.vocab.itos[:10]))

    n_words = len(_form.vocab)
    n_tags = len(_pos.vocab)
    n_chars = len(_chars.vocab)

    # Token-count-based batching: batches hold ~batch_size tokens, not examples.
    def batch_size_fn(new, count, sofar):
        return len(new.form) + 1 + sofar

    # iterators
    train_iter = Iterator(train_dataset,
                          batch_size,
                          train=True,
                          sort_within_batch=True,
                          batch_size_fn=batch_size_fn,
                          device=device)
    dev_iter = Iterator(dev_dataset,
                        32,
                        train=False,
                        sort_within_batch=True,
                        device=device)
    test_iter = Iterator(test_dataset,
                         32,
                         train=False,
                         sort_within_batch=True,
                         device=device)

    # uncomment to see what a mini-batch looks like numerically
    # e.g. some things are being inserted dynamically (ROOT at the start of seq,
    #   padding items, maybe UNKs..)
    # batch = next(iter(train_iter))
    # print("form", batch.form)
    # print("pos", batch.pos)
    # print("deprel", batch.deprel)
    # print("head", batch.head)

    # if n_iters or eval_every are negative, we set them to that many
    # number of epochs
    iters_per_epoch = (len(train_dataset) // batch_size) + 1
    if eval_every < 0:
        logger.info("Setting eval_every to %d epoch(s) = %d iters" %
                    (-1 * eval_every, -1 * eval_every * iters_per_epoch))
        eval_every = iters_per_epoch * eval_every

    if n_iters < 0:
        logger.info("Setting n_iters to %d epoch(s) = %d iters" %
                    (-1 * n_iters, -1 * n_iters * iters_per_epoch))
        n_iters = -1 * n_iters * iters_per_epoch

    # load up the model
    model = Tagger(n_words=n_words,
                   n_tags=n_tags,
                   n_chars=n_chars,
                   form_vocab=_form.vocab,
                   char_vocab=_chars.vocab,
                   pos_vocab=_pos.vocab,
                   **cfg)

    # set word vectors
    if glove:
        # Normalize pretrained vectors to unit std before copying them in.
        _form.vocab.vectors = _form.vocab.vectors / torch.std(
            _form.vocab.vectors)
        # print(torch.std(_form.vocab.vectors))
        model.encoder.embedding.weight.data.copy_(_form.vocab.vectors)
        model.encoder.embedding.weight.requires_grad = True

    model = model.cuda() if use_cuda else model

    start_iter = 1
    best_iter = 0
    best_pos_acc = -1.
    test_pos_acc = -1.

    # optimizer and learning rate scheduler
    trainable_parameters = [p for p in model.parameters() if p.requires_grad]
    if optimizer == 'sgd':
        optimizer = torch.optim.SGD(trainable_parameters,
                                    lr=lr,
                                    momentum=momentum)
    else:
        optimizer = torch.optim.Adam(trainable_parameters,
                                     lr=lr,
                                     betas=(adam_beta1, adam_beta2))

    # learning rate schedulers: exponential decay by default, or plateau-based
    # (on dev POS accuracy, hence mode='max') when plateau=True.
    if not plateau:
        scheduler = LambdaLR(optimizer, lr_lambda=lambda t: lr_decay**t)
    else:
        scheduler = ReduceLROnPlateau(optimizer,
                                      mode='max',
                                      factor=0.75,
                                      patience=5,
                                      min_lr=1e-4)

    # load model and vocabularies if resuming
    if resume:
        if os.path.isfile(resume):
            print("=> loading checkpoint '{}'".format(resume))
            checkpoint = torch.load(resume)
            start_iter = checkpoint['iter_i']
            best_pos_acc = checkpoint['best_pos_acc']
            test_pos_acc = checkpoint['test_pos_acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (iter {})".format(
                resume, checkpoint['iter_i']))
        else:
            print("=> no checkpoint found at '{}'".format(resume))

    print_parameters(model)

    # print some stuff just for fun
    logger.info("Most common words: %s" % _form.vocab.freqs.most_common(20))
    logger.info("Word vocab size: %s" % n_words)
    logger.info("Most common XPOS-tags: %s" % _pos.vocab.freqs.most_common())
    logger.info("POS vocab size: %s" % n_tags)
    # logger.info("Most common chars: %s" % _chars.nesting_field.vocab.freqs.most_common())
    logger.info("Chars vocab size: %s" % n_chars)

    print("First training example:")
    print_example(train_dataset[0])

    print("First dev example:")
    print_example(dev_dataset[0])

    print("First test example:")
    print_example(test_dataset[0])

    logger.info("Training starts..")
    upos_var, morph_var = None, None
    for iter_i in range(start_iter, n_iters + 1):

        # Exponential decay steps roughly once per epoch of the hard-coded
        # 912344-token corpus -- NOTE(review): corpus-specific constant.
        if not plateau and iter_i % (912344 // batch_size) == 0:
            scheduler.step()
        model.train()

        # NOTE(review): fresh iter() each step relies on torchtext's Iterator
        # resuming mid-epoch -- confirm with the torchtext version in use.
        batch = next(iter(train_iter))
        form_var, lengths = batch.form

        pos_var = batch.pos
        char_var, sentence_lengths, word_lengths = batch.chars
        lengths = lengths.view(-1).tolist()

        result = model(form_var=form_var,
                       char_var=char_var,
                       pos_var=pos_var,
                       lengths=lengths,
                       word_lengths=word_lengths)

        # rows sum to 1
        # print(torch.exp(output_graph).sum(-1))

        # print sizes
        # print(head_logits.data.cpu().size())
        targets = dict(pos=batch.pos)

        all_losses = model.get_loss(scores=result, targets=targets)

        loss = all_losses['loss']

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()
        optimizer.zero_grad()

        if iter_i % print_every == 0:

            # get scores for this batch
            if model.tagger == "linear":
                pos_predictions = result['output'].max(2)[1]
            else:
                pos_predictions = result['sequence']
            predictions = dict(pos=pos_predictions)
            targets = dict(pos=batch.pos)

            pos_acc = model.get_accuracy(predictions=predictions,
                                         targets=targets)

            if not plateau:
                lr = scheduler.get_lr()[0]
            else:
                lr = [group['lr'] for group in optimizer.param_groups][0]

            fmt = "Iter %08d loss %8.4f pos-acc %5.2f lr %.5f"

            logger.info(fmt % (iter_i, loss, pos_acc, lr))

        if iter_i % eval_every == 0:

            # parse dev set and save to file for official evaluation
            dev_out_path = 'dev.iter%08d.conll' % iter_i
            dev_out_path = os.path.join(output_dir, dev_out_path)
            predict_and_save(dataset=dev_dataset,
                             model=model,
                             dataset_path=dev_path,
                             out_path=dev_out_path)

            _dev_pos_acc = get_pos_acc(dev_path, dev_out_path)

            logger.info("Evaluation dev Iter %08d "
                        "pos-acc %5.2f" % (iter_i, _dev_pos_acc))

            # parse test set and save to file for official evaluation
            test_out_path = 'test.iter%08d.conll' % iter_i
            test_out_path = os.path.join(output_dir, test_out_path)
            predict_and_save(dataset=test_dataset,
                             model=model,
                             dataset_path=test_path,
                             out_path=test_out_path)
            _test_pos_acc = get_pos_acc(test_path, test_out_path)

            logger.info("Evaluation test Iter %08d "
                        "pos-acc %5.2f" % (iter_i, _test_pos_acc))

            if plateau:
                scheduler.step(_dev_pos_acc)

            if _dev_pos_acc > best_pos_acc:
                best_iter = iter_i
                best_pos_acc = _dev_pos_acc
                test_pos_acc = _test_pos_acc
                is_best = True
            else:
                is_best = False

            # Fix: include 'best_pos_acc' (read by the resume path above) and
            # forward the computed is_best flag instead of a hard-coded False.
            save_checkpoint(
                output_dir, {
                    'iter_i': iter_i,
                    'state_dict': model.state_dict(),
                    'best_iter': best_iter,
                    'best_pos_acc': best_pos_acc,
                    'test_pos_acc': test_pos_acc,
                    'optimizer': optimizer.state_dict(),
                }, is_best)

    logger.info("Done Training")
    logger.info(
        "Best model Iter %08d Dev POS-acc %12.4f Test POS-acc %12.4f " %
        (best_iter, best_pos_acc, test_pos_acc))
Esempio n. 9
0
                            n_output=[11, 168, 7, 1295],
                            input_channels=1)
model.cuda()

# Catalyst-style loader registry, ordered so "train" runs before "valid".
loaders = collections.OrderedDict()
loaders["train"] = train_loader
loaders["valid"] = val_loader

runner = SupervisedRunner(input_key="image",
                          output_key=None,
                          input_target_key=None)

optimizer = RAdam(model.parameters(), lr=args.lr, weight_decay=0.001)

# Monitored metric is maximized; LR shrinks by 0.75 after 3 stagnant epochs.
scheduler = ReduceLROnPlateau(optimizer=optimizer,
                              factor=0.75,
                              patience=3,
                              mode="max")

# One cross-entropy criterion per output head; the first three are class-
# weighted via get_w(...) (ny1-ny3 are defined elsewhere in the file).
criterions_dict = {
    "vowel_diacritic_loss": torch.nn.CrossEntropyLoss(weight=get_w(ny1)),
    "grapheme_root_loss": torch.nn.CrossEntropyLoss(weight=get_w(ny2)),
    "consonant_diacritic_loss": torch.nn.CrossEntropyLoss(weight=get_w(ny3)),
    "grapheme_loss": torch.nn.CrossEntropyLoss(),
}

callbacks = [
    MixupCutmixCallback(
        fields=["image"],
        output_key=(
            "logit_grapheme_root",
            "logit_vowel_diacritic",
    def __init__(self, model):
        """Trainer setup: folds, data loaders, loss, optimizer and LR scheduler.

        Hyper-parameters come from the module-level `args` namespace; the
        supplied `model` becomes self.net and is moved to cuda:0.
        """
        self.fold = args.fold
        self.total_folds = 5
        self.num_workers = 6
        self.batch_size = {
            "train": args.batch_size,
            "val": args.batch_size
        }  # 4
        # Gradient accumulation keeps the effective train batch size at 32.
        self.accumulation_steps = 32 // self.batch_size['train']
        self.lr = args.learning_rate
        self.num_epochs = args.epochs
        self.best_loss = float("inf")
        self.best_dice = 0
        self.phases = ["train", "val"]
        self.device = torch.device("cuda:0")
        # All newly created tensors default to CUDA floats from here on.
        torch.set_default_tensor_type("torch.cuda.FloatTensor")
        self.net = model
        self.criterion = MixedLoss(10.0, 2.0)

        if args.swa is True:
            # Stochastic Weight Averaging wraps a RAdam base optimizer.
            # base_opt = torch.optim.SGD(self.net.parameters(), lr=args.max_lr, momentum=args.momentum, weight_decay=args.weight_decay)
            base_opt = RAdam(self.net.parameters(), lr=self.lr)
            self.optimizer = SWA(base_opt,
                                 swa_start=38,
                                 swa_freq=1,
                                 swa_lr=args.min_lr)
            # self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, scheduler_step, args.min_lr)
        else:
            if args.optimizer.lower() == 'adam':
                self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr)
            elif args.optimizer.lower() == 'radam':
                self.optimizer = RAdam(
                    self.net.parameters(), lr=self.lr
                )  # betas=(args.beta1, args.beta2),weight_decay=args.weight_decay
            elif args.optimizer.lower() == 'sgd':
                self.optimizer = torch.optim.SGD(
                    self.net.parameters(),
                    lr=args.max_lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)

        # NOTE(review): no scheduler is set when args.swa is True or the
        # scheduler name matches neither branch below.
        if args.scheduler.lower() == 'reducelronplateau':
            self.scheduler = ReduceLROnPlateau(self.optimizer,
                                               mode="min",
                                               patience=args.patience,
                                               verbose=True)
        elif args.scheduler.lower() == 'clr':
            self.scheduler = CyclicLR(self.optimizer,
                                      base_lr=self.lr,
                                      max_lr=args.max_lr)
        self.net = self.net.to(self.device)
        # Let cuDNN autotune conv kernels (input sizes are fixed per phase).
        cudnn.benchmark = True
        self.dataloaders = {
            phase: provider(
                fold=args.fold,
                total_folds=5,
                data_folder=data_folder,
                df_path=train_rle_path,
                phase=phase,
                size=args.img_size_target,
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
                batch_size=self.batch_size[phase],
                num_workers=self.num_workers,
            )
            for phase in self.phases
        }
        # Per-phase running histories of losses and metrics.
        self.losses = {phase: [] for phase in self.phases}
        self.iou_scores = {phase: [] for phase in self.phases}
        self.dice_scores = {phase: [] for phase in self.phases}
        self.kaggle_metric = {phase: [] for phase in self.phases}
Esempio n. 11
0
def train_bertmhc(args):
    """Train a BERTMHC model, with optional multi-allele deconvolution.

    When args.deconvolution is set, training data is regenerated from the
    MA wrappers each epoch; otherwise fixed DataLoaders are built once.
    Early stopping monitors args.metric (negated, since EarlyStopping
    minimizes) and checkpoints to args.save.
    """
    torch.manual_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ###############################################################################
    # Load data
    ###############################################################################
    if args.deconvolution:
        trainMa = MAData(args.data + args.train,
                         sa_epochs=args.sa_epoch,
                         calibrate=args.calibrate,
                         negative=args.negative)
        valMa = MAData(args.data + args.eval,
                       sa_epochs=args.sa_epoch,
                       calibrate=args.calibrate,
                       negative=args.negative)
    else:
        trainset = BertDataset(args.data + args.train,
                               max_pep_len=args.peplen,
                               instance_weight=args.instance_weight)
        valset = BertDataset(args.data + args.eval,
                             max_pep_len=args.peplen,
                             instance_weight=args.instance_weight)
        train_data = DataLoader(trainset,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=16,
                                pin_memory=True,
                                collate_fn=trainset.collate_fn)
        # Validation uses a doubled batch size (no gradients to store).
        val_data = DataLoader(valset,
                              batch_size=args.batch_size * 2,
                              num_workers=16,
                              pin_memory=True,
                              collate_fn=valset.collate_fn)
        logger.info("Training on {0} samples, eval on {1}".format(
            len(trainset), len(valset)))

    ################
    # Load model
    ################
    if args.random_init:
        # Architecture only; weights randomly initialized from the config.
        config = ProteinBertConfig.from_pretrained('bert-base')
        model = BERTMHC(config)
    else:
        model = BERTMHC.from_pretrained('bert-base')

    # Fine-tune the full BERT backbone, not just the heads.
    for p in model.bert.parameters():
        p.requires_grad = True

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model = model.to(device)

    # loss: affinity head is plain BCE; mass-spec head is positively
    # re-weighted and left unreduced (per-sample weighting happens later).
    aff_criterion = nn.BCEWithLogitsLoss()
    w_pos = torch.tensor([args.w_pos]).to(device)
    mass_criterion = nn.BCEWithLogitsLoss(pos_weight=w_pos, reduction='none')

    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          nesterov=True)

    # args.metric is maximized; LR drops by 10x after 2 stagnant epochs.
    scheduler = ReduceLROnPlateau(optimizer,
                                  'max',
                                  patience=2,
                                  min_lr=1e-4,
                                  factor=0.1)

    early_stopping = EarlyStopping(patience=args.patience,
                                   verbose=True,
                                   saveto=args.save)

    for epoch in range(args.epochs):
        if args.deconvolution:
            # Regenerate training/validation sets each epoch using the current
            # model's mass predictions for deconvolution.
            trainset = BertDataset(trainMa.generate_training(
                model,
                args.peplen,
                score='mass_pred',
                batch_size=args.batch_size * 2),
                                   max_pep_len=args.peplen,
                                   instance_weight=args.instance_weight)
            valset = BertDataset(valMa.generate_training(
                model,
                args.peplen,
                score='mass_pred',
                batch_size=args.batch_size * 2),
                                 max_pep_len=args.peplen,
                                 instance_weight=args.instance_weight)
            train_data = DataLoader(trainset,
                                    batch_size=args.batch_size,
                                    shuffle=True,
                                    num_workers=16,
                                    pin_memory=True,
                                    collate_fn=trainset.collate_fn)
            val_data = DataLoader(valset,
                                  batch_size=args.batch_size,
                                  num_workers=16,
                                  pin_memory=True,
                                  collate_fn=valset.collate_fn)
            # NOTE(review): trainMa/valMa are used again after close() below —
            # presumably close() only releases file handles; confirm in MAData.
            trainMa.close()
            valMa.close()
            if epoch == trainMa.sa_epochs:
                print('Reset early stopping')
                # reset early stopping and scheduler
                # NOTE(review): _reset() is a private ReduceLROnPlateau API.
                early_stopping.reset()
                scheduler._reset()

        print("Training epoch {}".format(epoch))
        train_metrics = train(model, optimizer, train_data, device,
                              aff_criterion, mass_criterion, args.alpha,
                              scheduler)
        eval_metrics = evaluate(model, val_data, device, aff_criterion,
                                mass_criterion, args.alpha)
        eval_metrics['train_loss'] = train_metrics
        logs = eval_metrics

        scheduler.step(logs.get(args.metric))
        logging.info('Sample dict log: %s' % logs)

        # callbacks: EarlyStopping minimizes, so the maximized metric is negated.
        early_stopping(-logs.get(args.metric), model, optimizer)
        if early_stopping.early_stop or logs.get(args.metric) <= 0:
            if args.deconvolution and not trainMa.train_ma:
                # still training SA only model, now switch to training on MA immediately
                trainMa.train_ma = True
                valMa.train_ma = True
                print("Start training with multi-allele data.")
            else:
                print("Early stopping")
                break
class ResNet3D():
    """Train/evaluate a 3D ResNet-34 video classifier (UCF-101 style, 101 classes).

    Handles model construction, optional checkpoint resume, per-epoch
    training/validation, best-top-1 checkpointing, and aggregation of
    per-clip predictions into video-level accuracy.
    """

    def __init__(self, nb_epochs, lr, batch_size, resume, start_epoch,
                 evaluate, train_loader, val_loader, multi_gpu):
        self.nb_epochs = nb_epochs
        self.lr = lr
        self.batch_size = batch_size
        self.resume = resume            # checkpoint path to resume from, or falsy
        self.start_epoch = start_epoch
        self.evaluate = evaluate        # if truthy, run one validation pass up front
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.best_prec1 = 0             # best video-level top-1 accuracy so far
        self.multi_gpu = multi_gpu

    def build_model(self):
        """Build the network and set up the loss, optimizer and LR scheduler."""
        print('==> Build model and setup loss and optimizer')
        #build model
        model = resnet34()
        if self.multi_gpu:
            self.model = nn.DataParallel(model).cuda()
        else:
            self.model = model.cuda()
        #Loss function and optimizer
        self.criterion = nn.CrossEntropyLoss().cuda()
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         self.lr,
                                         momentum=0.9)
        # Reduce the LR when the validation loss plateaus for 2 epochs.
        self.scheduler = ReduceLROnPlateau(self.optimizer,
                                           'min',
                                           patience=2,
                                           verbose=True)

    def resume_and_evaluate(self):
        """Optionally restore a checkpoint, then optionally validate once."""
        if self.resume:
            if os.path.isfile(self.resume):
                print("==> loading checkpoint '{}'".format(self.resume))
                checkpoint = torch.load(self.resume)
                self.start_epoch = checkpoint['epoch']
                self.best_prec1 = checkpoint['best_prec1']
                self.model.load_state_dict(checkpoint['state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                print("==> loaded checkpoint '{}' (epoch {}) (best_prec1 {})".
                      format(self.resume, checkpoint['epoch'],
                             self.best_prec1))
            else:
                print("==> no checkpoint found at '{}'".format(self.resume))
        if self.evaluate:
            prec1, val_loss = self.validate_1epoch()

    def run(self):
        """Full training loop: build, resume, then train/validate each epoch."""
        self.build_model()
        self.resume_and_evaluate()

        cudnn.benchmark = True
        for self.epoch in range(self.start_epoch, self.nb_epochs):
            print('==> Epoch:[{0}/{1}][training stage]'.format(
                self.epoch, self.nb_epochs))
            self.train_1epoch()
            print('==> Epoch:[{0}/{1}][validation stage]'.format(
                self.epoch, self.nb_epochs))
            prec1, val_loss = self.validate_1epoch()
            self.scheduler.step(val_loss)

            is_best = prec1 > self.best_prec1
            if is_best:
                self.best_prec1 = prec1

            save_checkpoint(
                {
                    'epoch': self.epoch,
                    'state_dict': self.model.state_dict(),
                    'best_prec1': self.best_prec1,
                    'optimizer': self.optimizer.state_dict()
                }, is_best)

    def train_1epoch(self):
        """Run one training epoch and append averaged metrics to the CSV log."""
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        #switch to train mode
        self.model.train()
        end = time.time()
        # mini-batch training
        for i, (data, label) in enumerate(tqdm(self.train_loader)):

            # measure data loading time
            data_time.update(time.time() - end)

            # `non_blocking=True` replaces the pre-0.4 `async=True`, which is
            # a SyntaxError on Python >= 3.7 (`async` became a keyword).
            label_copy = label.cuda(non_blocking=True)
            data_var = Variable(data).cuda()
            label_var = Variable(label).cuda()

            # compute output
            output = self.model(data_var)
            loss = self.criterion(output, label_var)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, label_copy, topk=(1, 5))
            # loss.item() replaces the deprecated 0-dim indexing loss.data[0].
            losses.update(loss.item(), data.size(0))
            top1.update(prec1[0], data.size(0))
            top5.update(prec5[0], data.size(0))

            # compute gradient and do SGD step
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

        info = {
            'Epoch': [self.epoch],
            'Batch Time': [round(batch_time.avg, 3)],
            'Data Time': [round(data_time.avg, 3)],
            'Loss': [round(losses.avg, 5)],
            'Prec@1': [round(top1.avg, 4)],
            'Prec@5': [round(top5.avg, 4)]
        }
        record_info(info, 'record/training.csv', 'train')

    def validate_1epoch(self):
        """Validate one epoch, accumulating per-video prediction sums.

        Returns (video_top1, video_loss); top-5 is logged but not returned.
        """
        batch_time = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        # switch to evaluate mode
        self.model.eval()
        self.dic_video_level_preds = {}
        end = time.time()
        for i, (keys, data, label) in enumerate(tqdm(self.val_loader)):

            #label = label.cuda(async=True)
            data_var = Variable(data)
            label_var = Variable(label)
            data_var = data_var.cuda()
            label_var = label_var.cuda()

            # compute output
            output = self.model(data_var)
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            # Calculate video level prediction: sum clip logits per video.
            preds = output.data.cpu().numpy()
            nb_data = preds.shape[0]
            for j in range(nb_data):
                videoName = keys[j].split('/', 1)[0]
                if videoName not in self.dic_video_level_preds.keys():
                    self.dic_video_level_preds[videoName] = preds[j, :]
                else:
                    self.dic_video_level_preds[videoName] += preds[j, :]

        video_top1, video_top5, video_loss = self.frame2_video_level_accuracy()

        info = {
            'Epoch': [self.epoch],
            'Batch Time': [round(batch_time.avg, 3)],
            'Loss': [round(video_loss, 5)],
            'Prec@1': [round(video_top1, 3)],
            'Prec@5': [round(video_top5, 3)]
        }
        record_info(info, 'record/testing.csv', 'test')
        return video_top1, video_loss

    def frame2_video_level_accuracy(self):
        """Score the accumulated per-video prediction sums against ground truth.

        Returns (top1, top5, loss) at video level.
        """
        # Ground-truth labels: pickled dict mapping video name -> label.
        with open(
                '/home/ubuntu/cvlab/pytorch/ucf101_two_stream/dic_video_label.pickle',
                'rb') as f:
            video_label = pickle.load(f)
        # (redundant f.close() removed -- the `with` block already closes f)

        # Normalize the known UCF-101 class-name spelling inconsistency.
        dic_video_label = {}
        for video in video_label:
            n, g = video.split('_', 1)
            if n == 'HandStandPushups':
                key = 'HandstandPushups_' + g
            else:
                key = video
            dic_video_label[key] = video_label[video]

        correct = 0
        video_level_preds = np.zeros((len(self.dic_video_level_preds), 101))
        video_level_labels = np.zeros(len(self.dic_video_level_preds))
        ii = 0
        for key in sorted(self.dic_video_level_preds.keys()):
            # NOTE(review): `name` differs from `key` only if the key contains
            # a '-'; confirm this matches the key format validate_1epoch stores,
            # otherwise the lookups below could miss or KeyError.
            name = key.split('-', 1)[0]

            preds = self.dic_video_level_preds[name]
            label = int(dic_video_label[name]) - 1

            video_level_preds[ii, :] = preds
            video_level_labels[ii] = label
            ii += 1
            if np.argmax(preds) == (label):
                correct += 1

        #top1 top5
        video_level_labels = torch.from_numpy(video_level_labels).long()
        video_level_preds = torch.from_numpy(video_level_preds).float()

        loss = self.criterion(
            Variable(video_level_preds).cuda(),
            Variable(video_level_labels).cuda())
        top1, top5 = accuracy(video_level_preds,
                              video_level_labels,
                              topk=(1, 5))

        top1 = float(top1.numpy())
        top5 = float(top5.numpy())

        return top1, top5, loss.data.cpu().numpy()
Esempio n. 13
0
def main(args):
    """Train an ML extractor ('ff' or 'rnn') and save it under args.path.

    Builds the vocabulary and data batchers, configures the network,
    criterion and optimizer, dumps the experiment meta data, and runs the
    trainer.
    """
    # NOTE: a leftover `args.debug = True` previously forced debug mode here,
    # silently overriding the command-line flag; the caller's value is now
    # honored.
    assert args.net_type in ['ff', 'rnn']
    # create data batcher, vocabulary
    # batcher
    with open(join(DATA_DIR, 'vocab_cnt.pkl'), 'rb') as f:
        wc = pkl.load(f)
    word2id = make_vocab(wc, args.vsize)
    train_batcher, val_batcher = build_batchers(args.net_type, word2id,
                                                args.cuda, args.debug)
    # make net
    net, net_args = configure_net(args.net_type, len(word2id), args.emb_dim,
                                  args.conv_hidden, args.lstm_hidden,
                                  args.lstm_layer, args.bi)
    if args.w2v:
        # NOTE: the pretrained embedding having the same dimension
        #       as args.emb_dim should already be trained
        embedding, _ = make_embedding({i: w
                                       for w, i in word2id.items()}, args.w2v)
        net.set_embedding(embedding)

    # configure training setting
    criterion, train_params = configure_training(args.net_type, 'adam',
                                                 args.lr, args.clip,
                                                 args.decay, args.batch,
                                                 dict(word2id))

    # save experiment setting
    if not exists(args.path):
        os.makedirs(args.path)
    with open(join(args.path, 'vocab.pkl'), 'wb') as f:
        pkl.dump(word2id, f, pkl.HIGHEST_PROTOCOL)
    meta = {}
    meta['net'] = 'ml_{}_extractor'.format(args.net_type)
    meta['net_args'] = net_args
    # (sic) 'traing_params' is kept misspelled for compatibility with any
    # existing readers of meta.json.
    meta['traing_params'] = train_params
    with open(join(args.path, 'meta.json'), 'w') as f:
        json.dump(meta, f, indent=4)

    # prepare trainer
    val_fn = basic_validate(net, criterion)
    grad_fn = get_basic_grad_fn(net, args.clip)
    optimizer = optim.Adam(net.parameters(), **train_params['optimizer'][1])
    scheduler = ReduceLROnPlateau(optimizer,
                                  'min',
                                  verbose=True,
                                  factor=args.decay,
                                  min_lr=0,
                                  patience=args.lr_p)

    if args.cuda:
        net = net.cuda()
    pipeline = BasicPipeline(meta['net'], net, train_batcher, val_batcher,
                             args.batch, val_fn, criterion, optimizer, grad_fn)
    trainer = BasicTrainer(pipeline, args.path, args.ckpt_freq, args.patience,
                           scheduler)

    print('start training with the following hyper-parameters:')
    print(meta)
    trainer.train()
Esempio n. 14
0
    # model
    # Build the classifier and move it to the target device; `args` and
    # `device` are defined by the enclosing (not shown) function.
    model = Net(nntype=args.nntype,
                nunits=args.nunits,
                nhidden=args.nhidden,
                sigma=args.sigma).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)#, momentum=0.5)
    # Select the classification loss by name; any other value aborts.
    if args.loss == 'Hinge':
        criterion = nn.MultiLabelSoftMarginLoss()
    elif args.loss == 'CrossEntropy':
        criterion = nn.CrossEntropyLoss()
    else:
        print('Unsupported loss')
        sys.exit()
    # Decay the LR by 10x when the monitored quantity stops improving.
    scheduler = ReduceLROnPlateau(optimizer, mode='min',
                            factor=0.1,
                            patience=args.lr_patience,
                            verbose=True,
                            min_lr = args.min_lr)

    print(model)
    # for name,val in model.named_parameters():
    #     print(name, val)
    # data
    batch_size = args.batch_size
    # MNIST train/validation sets as plain tensors (no augmentation).
    train_dataset = datasets.MNIST('./data',
                                train=True,
                                download=True,
                                transform=transforms.ToTensor())
    validation_dataset = datasets.MNIST('./data',
                                        train=False,
                                        transform=transforms.ToTensor())
Esempio n. 15
0
def train(config: Dict):
    """Train the BiLSTM CCG supertagger described by ``config``.

    Loads and splits the data, builds the model/optimizer, then runs the
    training loop with per-epoch validation, LR scheduling on the training
    loss and a checkpoint every 5 epochs.
    """
    device = torch.device(
        f"cuda:{config['device']}" if torch.cuda.is_available() else 'cpu')
    config['device'] = device
    # load the data

    if config['random_train_val'] == 'True':
        x, y = load_data(config['total_path'], config['tag_path'])
        # First carve off the held-out test set...
        x_train, x_test, y_train, y_test = train_test_split(
            x, y, test_size=config['test_ratio'])
        train_ratio = config['train_ratio'] / (1 - config['test_ratio'])
        # ...then split only the REMAINING data into train/val.  The original
        # code re-split the full (x, y) here, which leaked test samples into
        # the training and validation sets.
        x_train, x_val, y_train, y_val = train_test_split(
            x_train, y_train, train_size=train_ratio)
    else:
        x_train, y_train, x_val, y_val = load_train_val(config)

    train_data = CCGBankData(x_train, y_train)
    val_data = CCGBankData(x_val, y_val)

    train_loader = DataLoader(train_data,
                              batch_size=config['batch_size'],
                              shuffle=True,
                              collate_fn=collate_fn)
    val_loader = DataLoader(val_data,
                            batch_size=config['validate_batch_size'],
                            shuffle=True,
                            collate_fn=collate_fn)

    model = BiLSTM(config).to(device)
    criterion = nn.CrossEntropyLoss()
    all_parameters = [{'params': model.parameters(), 'weight_decay': 0.0}]
    optimizer = optim.AdamW(all_parameters,
                            lr=float(config['lr']),
                            weight_decay=float(config['weight_decay']))

    model.train()

    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='min',
                                  factor=0.1,
                                  patience=5,
                                  verbose=True)

    for epoch in range(config['epoch_num']):
        correct, total = 0, 0
        model.train()
        model.bert_model.train()
        for step, batch in enumerate(train_loader):
            optimizer.zero_grad()
            x, y = batch
            # Each sentence must have exactly one tag per token.
            if len(x[0]) != len(y[0]):
                print(x[0], y[0])
                raise SimpleCCGException('unmatched x and y!')
            labels = cat_labels(y).to(device)
            total += labels.size(0)
            output, each_len = model(x)
            _, predicted = torch.max(output, 1)
            try:
                correct += (predicted == labels).sum().item()
                loss = criterion(output, labels)
                loss.backward()
                optimizer.step()
            except RuntimeError as err:
                # Dump the offending batch and chain the original error so
                # the real cause is not lost.
                print_sentences(x)
                raise SimpleCCGException from err

        print(
            f'epoch {epoch}: loss: {loss.item()} accuracy: {correct/total:.2f}'
        )
        # NOTE: steps the scheduler on the LAST batch's training loss, not a
        # validation metric.
        scheduler.step(loss)
        validate(config, val_loader, model, epoch)
        if (epoch + 1) % 5 == 0:
            torch.save(
                {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss
                }, '{}/bert_bilstm_{}_{}_{}.pt'.format(config['model_path'],
                                                       config['lr'], 'Adam',
                                                       epoch + 1))
Esempio n. 16
0
def run(args):
    """Train a VGG16-BN classifier on CIFAR-10 with ignite.

    Wires up trainer/evaluator engines, early stopping, checkpointing,
    timers and optional LR scheduling, then runs for ``args.e`` epochs.
    """
    epochs = args.e
    patience = args.p
    lr_patience = args.lp
    lr = args.lr
    batch_size = args.b
    data_dir = args.dir
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    train_loader, valid_loader, test_loader = get_loaders(batch_size, data_dir)
    # NOTE(review): the model is not explicitly moved to `device` here;
    # presumably the ignite engine factories handle that via their `device=`
    # argument -- confirm for the ignite version in use.
    model = vgg16_bn(num_classes=10) #VGG('VGG16')
    if args.adam:
        print('Using adam optimizer.')
        optim = Adam(model.parameters(), lr=lr, weight_decay=5e-4)
    else:
        print('Using SGD optimizer.')
        optim = SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    trainer = create_supervised_trainer_with_metrics(model, optim, F.cross_entropy,
                                                     metrics={'accuracy': Accuracy(),
                                                              'loss': Loss(F.cross_entropy)},
                                                     device=device)
    evaluator = create_supervised_evaluator(model, metrics={'accuracy': Accuracy(),
                                                            'loss': Loss(F.cross_entropy)},
                                            device=device)

    # Custom event, fired manually after each validation run (see below).
    evaluator.register_events('validation_completed')

    # Check early stopping conditions after validation is completed
    def score_function(engine):
        # EarlyStopping expects "higher is better", so negate the loss.
        val_loss = engine.state.metrics['loss']
        return -val_loss

    handler = EarlyStopping(patience=patience, score_function=score_function, trainer=trainer)
    evaluator.add_event_handler('validation_completed', handler)

    # Save model checkpoints
    handler = ModelCheckpoint(args.dir, 'cifar10', save_interval=10)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, {'model': model, 'optim': optim})

    # Separate checkpoint written once when training completes.
    handler = ModelCheckpoint(args.dir, 'cifar10_end', save_interval=1)
    # trainer.add_event_handler(Events.EXCEPTION_RAISED, handler, {'model': model, 'optim': optim})
    trainer.add_event_handler(Events.COMPLETED, handler, {'model': model, 'optim': optim})

    # Setup timer
    timer = Timer(average=True)
    timer.attach(trainer, step=Events.EPOCH_COMPLETED)

    epoch_timer = Timer(average=False)
    epoch_timer.attach(trainer, start=Events.EPOCH_STARTED, pause=Events.EPOCH_COMPLETED,)

    # Set up learning rate scheduling
    if lr_patience != 0:
        scheduler = ReduceLROnPlateau(optim, patience=lr_patience, verbose=True)
        @evaluator.on('validation_completed')
        def scheduler_step(engine):
            scheduler.step(engine.state.metrics['loss'])

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_validation_results(engine):
        step = engine.state.epoch
        evaluate(trainer, step=step)
        evaluator.run(valid_loader)
        evaluate(evaluator, step=step, prefix='valid_')
        # Fire the custom event so early stopping / LR scheduling run on the
        # freshly computed validation metrics.
        evaluator.fire_event('validation_completed')
        print_floyd_metric('zz_time', epoch_timer.value(), step)

    @trainer.on(Events.COMPLETED)
    def log_test_results(engine):
        step = engine.state.epoch
        evaluator.run(test_loader)
        evaluate(evaluator, step=step, prefix='test_')


    trainer.run(train_loader, max_epochs=epochs)
    total_time = timer.total / 3600
    print(f'Total training time: {total_time:.2f}h')
    print(f'Average time per epoch: {timer.value()}s')
#%%
# Reconstruction networks (project-specific); move to the GPU device `gpu`
# defined earlier in the script.
G = UnetClass().to(gpu)
#G=SmallModel1().to(gpu)

GV = SmallModel().to(gpu)
#G.load_state_dict(torch.load('wts-10U2.pt'))
#GV.load_state_dict(torch.load('wts-10V2.pt'))
#GV.load_state_dict(torch.load('./PTmodels/27Oct_112451am_500ep_27oct/wts-500.pt'))

#optimizer=torch.optim.SGD([{'params':G.parameters(),'lr':5e-3,'momentum':0.9}])
# NOTE(review): only G's parameters are optimized; GV stays frozen unless the
# commented-out joint optimizer below is used -- confirm this is intended.
optimizer = torch.optim.AdamW([{'params': G.parameters(), 'lr': 1e-3}])

#optimizer=torch.optim.AdamW([{'params':G.parameters(),'lr':1e-3},{'params':GV.parameters(),'lr':1e-3}])
scheduler = ReduceLROnPlateau(optimizer,
                              mode='min',
                              factor=0.7,
                              patience=6,
                              verbose=True,
                              min_lr=5e-5)

# Training file list and a random permutation of its indices.
trnFiles = os.listdir('/Shared/lss_jcb/abdul/prashant_cardiac_data/Data/d2/')
sz = len(trnFiles)
#data2=np.zeros((sz,1,n_select,N,N)).astype(np.complex64)

rndm = random.sample(range(sz), sz)
#%%
# NUFFT forward/adjoint operators (torchkbnufft); `nx`, `dtype` and `gpu`
# are defined earlier in the script.
nuf_ob = KbNufft(im_size=(nx, nx), norm='ortho').to(dtype)
nuf_ob = nuf_ob.to(gpu)

adjnuf_ob = AdjKbNufft(im_size=(nx, nx), norm='ortho').to(dtype)
adjnuf_ob = adjnuf_ob.to(gpu)
Esempio n. 18
0
class UNetExperiment(PytorchExperiment):
    """
    The UnetExperiment is inherited from the PytorchExperiment. It implements the basic life cycle for a segmentation task with UNet(https://arxiv.org/abs/1505.04597).
    It is optimized to work with the provided NumpyDataLoader.

    The basic life cycle of a UnetExperiment is the same s PytorchExperiment:

        setup()
        (--> Automatically restore values if a previous checkpoint is given)
        prepare()

        for epoch in n_epochs:
            train()
            validate()
            (--> save current checkpoint)

        end()
    """
    def setup(self):
        """Set up data loaders, model, losses, optimizer and scheduler.

        Reads the train/val/test split for the configured fold, builds the
        UNet and, if configured, restores weights from a checkpoint and/or
        freezes part of the network for fine-tuning.
        """
        pkl_dir = self.config.split_dir
        with open(os.path.join(pkl_dir, "splits.pkl"), 'rb') as f:
            splits = pickle.load(f)

        tr_keys = splits[self.config.fold]['train']
        val_keys = splits[self.config.fold]['val']
        test_keys = splits[self.config.fold]['test']

        self.device = torch.device(
            self.config.device if torch.cuda.is_available() else "cpu")

        self.train_data_loader = NumpyDataSet(
            self.config.data_dir,
            target_size=self.config.patch_size,
            batch_size=self.config.batch_size,
            keys=tr_keys)
        self.val_data_loader = NumpyDataSet(self.config.data_dir,
                                            target_size=self.config.patch_size,
                                            batch_size=self.config.batch_size,
                                            keys=val_keys,
                                            mode="val",
                                            do_reshuffle=False)
        self.test_data_loader = NumpyDataSet(
            self.config.data_test_dir,
            target_size=self.config.patch_size,
            batch_size=self.config.batch_size,
            keys=test_keys,
            mode="test",
            do_reshuffle=False)
        self.model = UNet(num_classes=self.config.num_classes,
                          in_channels=self.config.in_channels)

        self.model.to(self.device)

        # We use a combination of DICE-loss and CE-Loss in this example.
        # This proved good in the medical segmentation decathlon.
        self.dice_loss = SoftDiceLoss(
            batch_dice=True)  # Softmax for DICE Loss!
        self.ce_loss = torch.nn.CrossEntropyLoss(
        )  # No softmax for CE Loss -> is implemented in torch!

        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.config.learning_rate)

        self.scheduler = ReduceLROnPlateau(self.optimizer, 'min')

        # If directory for checkpoint is provided, we load it.
        if self.config.do_load_checkpoint:
            if self.config.checkpoint_dir == '':
                print('Checkpoint_dir is empty, training from scratch.')
            else:
                # ("model",): the original passed ("model"), which is just the
                # string "model", not a tuple; load_checkpoint's membership
                # tests only worked on it by accident of substring search.
                self.load_checkpoint(name=self.config.checkpoint_filename,
                                     save_types=("model",),
                                     path=self.config.checkpoint_dir)

            if self.config.fine_tune in ['expanding_all', 'expanding_plus1']:
                # freeze part of the network, fine-tune the other part
                unfreeze_block_parameters(
                    model=self.model, fine_tune_option=self.config.fine_tune)
                # else just train the whole network

        self.save_checkpoint(name="checkpoint_start")
        self.elog.print('Experiment set up.')

    # overloaded method from the base class PytorchExperiment
    def load_checkpoint(self,
                        name="checkpoint",
                        save_types=("model", "optimizer", "simple", "th_vars",
                                    "results"),
                        n_iter=None,
                        iter_format="{:05d}",
                        prefix=False,
                        path=None):
        """
        Loads a checkpoint and restores the experiment.
        Make sure you have your torch stuff already on the right devices beforehand,
        otherwise this could lead to errors e.g. when making a optimizer step
        (and for some reason the Adam states are not already on the GPU:
        https://discuss.pytorch.org/t/loading-a-saved-model-for-continue-training/17244/3 )
        Args:
            name (str): The name of the checkpoint file
            save_types (list or tuple): What kind of member variables should be loaded? Choices are:
                "model" <-- Pytorch models,
                "optimizer" <-- Optimizers,
                "simple" <-- Simple python variables (basic types and lists/tuples),
                "th_vars" <-- torch tensors,
                "results" <-- The result dict
            n_iter (int): Number of iterations. Together with the name, defined by the iter_format,
                a file name will be created and searched for.
            iter_format (str): Defines how the name and the n_iter will be combined.
            prefix (bool): If True, the formatted n_iter will be prepended, otherwise appended.
            path (str): If no path is given then it will take the current experiment dir and formatted
                name, otherwise it will simply use the path and the formatted name to define the
                checkpoint file.
        """
        # Without an experiment logger there is nothing to load from.
        if self.elog is None:
            return

        model_dict = {}
        optimizer_dict = {}
        simple_dict = {}
        th_vars_dict = {}
        results_dict = {}

        # Collect the CURRENT objects for each requested category; the elog
        # loader below restores the saved state into them.
        if "model" in save_types:
            model_dict = self.get_pytorch_modules()
        if "optimizer" in save_types:
            optimizer_dict = self.get_pytorch_optimizers()
        if "simple" in save_types:
            simple_dict = self.get_simple_variables()
        if "th_vars" in save_types:
            th_vars_dict = self.get_pytorch_variables()
        if "results" in save_types:
            results_dict = {"results": self.results}

        # Merge all categories into one flat dict keyed by attribute name.
        checkpoint_dict = {
            **model_dict,
            **optimizer_dict,
            **simple_dict,
            **th_vars_dict,
            **results_dict
        }

        # Build the iteration-qualified file name, e.g. "checkpoint_00042".
        if n_iter is not None:
            name = name_and_iter_to_filename(name,
                                             n_iter,
                                             ".pth.tar",
                                             iter_format=iter_format,
                                             prefix=prefix)

        # Jorg Begin
        # if self.config.dont_load_lastlayer:
        #     exclude_layer_dict = {'model': ['model.model.5.weight', 'model.model.5.bias']}
        # else:
        #     exclude_layer_dict = {}
        exclude_layer_dict = {}
        # Jorg End

        # Either load via the experiment logger's own directory, or from an
        # explicit path (trailing slash stripped so join() behaves).
        if path is None:
            restore_dict = self.elog.load_checkpoint(name=name,
                                                     **checkpoint_dict)
        else:
            checkpoint_path = os.path.join(path, name)
            if checkpoint_path.endswith("/"):
                checkpoint_path = checkpoint_path[:-1]
            restore_dict = self.elog.load_checkpoint_static(
                checkpoint_file=checkpoint_path,
                exclude_layer_dict=exclude_layer_dict,
                **checkpoint_dict)

        # Write the restored values back onto the experiment instance.
        self.update_attributes(restore_dict)

    def train(self, epoch):
        """Train for one epoch over self.train_data_loader.

        Every `plot_freq` batches the current loss and image grids (input,
        mask, argmax prediction, raw prediction) are logged.
        """
        self.elog.print('=====TRAIN=====')
        self.model.train()

        data = None
        batch_counter = 0
        for data_batch in self.train_data_loader:

            self.optimizer.zero_grad()

            # Shape of data_batch = [1, b, c, w, h]
            # Desired shape = [b, c, w, h]
            # Move data and target to the GPU
            data = data_batch['data'][0].float().to(self.device)
            target = data_batch['seg'][0].long().to(self.device)

            pred = self.model(data)
            pred_softmax = F.softmax(
                pred, dim=1
            )  # We calculate a softmax, because our SoftDiceLoss expects that as an input. The CE-Loss does the softmax internally.

            # NOTE(review): .squeeze() removes ALL singleton dimensions and
            # would also drop the batch axis when batch_size == 1 -- confirm
            # batch_size > 1 (or prefer squeeze(1)).
            loss = self.dice_loss(pred_softmax,
                                  target.squeeze()) + self.ce_loss(
                                      pred, target.squeeze())

            loss.backward()
            self.optimizer.step()

            # Some logging and plotting
            if (batch_counter % self.config.plot_freq) == 0:
                self.elog.print('Epoch: {0} Loss: {1:.4f}'.format(
                    self._epoch_idx, loss))

                # Fractional epoch counter: epoch + share of batches done.
                self.add_result(
                    value=loss.item(),
                    name='Train_Loss',
                    tag='Loss',
                    counter=epoch +
                    (batch_counter /
                     self.train_data_loader.data_loader.num_batches))

                self.clog.show_image_grid(data.float().cpu(),
                                          name="data",
                                          normalize=True,
                                          scale_each=True,
                                          n_iter=epoch)
                self.clog.show_image_grid(target.float().cpu(),
                                          name="mask",
                                          title="Mask",
                                          n_iter=epoch)
                self.clog.show_image_grid(torch.argmax(pred.cpu(),
                                                       dim=1,
                                                       keepdim=True),
                                          name="unt_argmax",
                                          title="Unet",
                                          n_iter=epoch)
                self.clog.show_image_grid(pred.cpu()[:, 1:2, ],
                                          name="unt",
                                          normalize=True,
                                          scale_each=True,
                                          n_iter=epoch)

            batch_counter += 1

        assert data is not None, 'data is None. Please check if your dataloader works properly'

    def validate(self, epoch):
        """Validate for one epoch, step the LR scheduler on the mean loss,
        and log metrics plus image grids for the LAST validation batch.
        """
        self.elog.print('VALIDATE')
        self.model.eval()

        data = None
        loss_list = []

        with torch.no_grad():
            for data_batch in self.val_data_loader:
                data = data_batch['data'][0].float().to(self.device)
                target = data_batch['seg'][0].long().to(self.device)

                pred = self.model(data)
                pred_softmax = F.softmax(
                    pred, dim=1
                )  # We calculate a softmax, because our SoftDiceLoss expects that as an input. The CE-Loss does the softmax internally.

                # NOTE(review): same squeeze() batch-dim caveat as in train().
                loss = self.dice_loss(pred_softmax,
                                      target.squeeze()) + self.ce_loss(
                                          pred, target.squeeze())
                loss_list.append(loss.item())

        assert data is not None, 'data is None. Please check if your dataloader works properly'
        # Drive ReduceLROnPlateau with the epoch-mean validation loss.
        self.scheduler.step(np.mean(loss_list))

        self.elog.print('Epoch: %d Loss: %.4f' %
                        (self._epoch_idx, np.mean(loss_list)))

        self.add_result(value=np.mean(loss_list),
                        name='Val_Loss',
                        tag='Loss',
                        counter=epoch + 1)

        self.clog.show_image_grid(data.float().cpu(),
                                  name="data_val",
                                  normalize=True,
                                  scale_each=True,
                                  n_iter=epoch)
        self.clog.show_image_grid(target.float().cpu(),
                                  name="mask_val",
                                  title="Mask",
                                  n_iter=epoch)
        self.clog.show_image_grid(torch.argmax(pred.data.cpu(),
                                               dim=1,
                                               keepdim=True),
                                  name="unt_argmax_val",
                                  title="Unet",
                                  n_iter=epoch)
        self.clog.show_image_grid(pred.data.cpu()[:, 1:2, ],
                                  name="unt_val",
                                  normalize=True,
                                  scale_each=True,
                                  n_iter=epoch)

    def test(self):
        """Evaluate the model on the test set.

        Groups per-sample argmax predictions and ground truths by file name,
        aggregates them into a JSON score report via ``aggregate_scores``, and
        (optionally) visualizes the segmentations of one specific batch.
        """
        from evaluation.evaluator import aggregate_scores, Evaluator
        from collections import defaultdict

        self.elog.print('=====TEST=====')
        self.model.eval()

        predictions = defaultdict(list)
        ground_truths = defaultdict(list)

        n_batches = 0

        if self.config.visualize_segm:
            color_class_converter = LabelTensorToColor()

        with torch.no_grad():
            for batch in self.test_data_loader:
                print('testing...', n_batches)
                n_batches += 1

                # Get data_batches
                mr_data = batch['data'][0].float().to(self.device)
                mr_target = batch['seg'][0].float().to(self.device)

                pred = self.model(mr_data)
                pred_argmax = torch.argmax(pred.data.cpu(), dim=1,
                                           keepdim=True)

                # Collect per-sample results keyed by source file name.
                for i, fname in enumerate(batch['fnames']):
                    predictions[fname[0]].append(
                        pred_argmax[i].detach().cpu().numpy())
                    ground_truths[fname[0]].append(
                        mr_target[i].detach().cpu().numpy())

                # Visualization is tied to one hard-coded batch index.
                if n_batches == 35 and self.config.visualize_segm:
                    segm_visualization(mr_data, mr_target, pred_argmax,
                                       color_class_converter, self.config)

        test_ref_list = [
            (np.stack(predictions[key]), np.stack(ground_truths[key]))
            for key in predictions.keys()
        ]

        scores = aggregate_scores(
            test_ref_list,
            evaluator=Evaluator,
            json_author=self.config.author,
            json_task=self.config.name,
            json_name=self.config.name,
            json_output_file=self.elog.work_dir +
            "/{}_".format(self.config.author) + self.config.name + '.json')

        self.scores = scores

        print("Scores:\n", scores)
Esempio n. 19
0
    def setup(self):
        """Build everything needed to train one cross-validation fold.

        Loads the train/val/test split for ``self.config.fold`` from
        ``splits.pkl``, creates the three NumpyDataSet loaders, moves a UNet
        to the configured device, sets up the Dice+CE losses, Adam and a
        ReduceLROnPlateau scheduler, and optionally restores a checkpoint
        (with optional partial freezing for fine-tuning).
        """
        pkl_dir = self.config.split_dir
        with open(os.path.join(pkl_dir, "splits.pkl"), 'rb') as f:
            splits = pickle.load(f)

        tr_keys = splits[self.config.fold]['train']
        val_keys = splits[self.config.fold]['val']
        test_keys = splits[self.config.fold]['test']

        self.device = torch.device(
            self.config.device if torch.cuda.is_available() else "cpu")

        self.train_data_loader = NumpyDataSet(
            self.config.data_dir,
            target_size=self.config.patch_size,
            batch_size=self.config.batch_size,
            keys=tr_keys)
        self.val_data_loader = NumpyDataSet(self.config.data_dir,
                                            target_size=self.config.patch_size,
                                            batch_size=self.config.batch_size,
                                            keys=val_keys,
                                            mode="val",
                                            do_reshuffle=False)
        self.test_data_loader = NumpyDataSet(
            self.config.data_test_dir,
            target_size=self.config.patch_size,
            batch_size=self.config.batch_size,
            keys=test_keys,
            mode="test",
            do_reshuffle=False)
        self.model = UNet(num_classes=self.config.num_classes,
                          in_channels=self.config.in_channels)

        self.model.to(self.device)

        # We use a combination of DICE-loss and CE-Loss in this example.
        # This proved good in the medical segmentation decathlon.
        self.dice_loss = SoftDiceLoss(
            batch_dice=True)  # Softmax for DICE Loss!
        self.ce_loss = torch.nn.CrossEntropyLoss(
        )  # No softmax for CE Loss -> is implemented in torch!

        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.config.learning_rate)

        self.scheduler = ReduceLROnPlateau(self.optimizer, 'min')

        # If directory for checkpoint is provided, we load it.
        if self.config.do_load_checkpoint:
            if self.config.checkpoint_dir == '':
                print('Checkpoint_dir is empty, training from scratch.')
            else:
                # BUG FIX: ("model") is just the string "model", not a tuple.
                # save_types expects an iterable of component names, so pass a
                # proper one-element tuple.
                self.load_checkpoint(name=self.config.checkpoint_filename,
                                     save_types=("model",),
                                     path=self.config.checkpoint_dir)

            if self.config.fine_tune in ['expanding_all', 'expanding_plus1']:
                # freeze part of the network, fine-tune the other part
                unfreeze_block_parameters(
                    model=self.model, fine_tune_option=self.config.fine_tune)
                # else just train the whole network

        self.save_checkpoint(name="checkpoint_start")
        self.elog.print('Experiment set up.')
Esempio n. 20
0
    def train(self, epochs, tolerance=1e-8, patience=5):
        """Train the model with early stopping on the validation loss.

        Args:
            epochs: maximum number of epochs to train.
            tolerance: minimal decrease in validation loss tolerated.
            patience: if the validation loss does not decrease by at least
                ``tolerance`` within ``patience`` epochs, training stops early.
        """
        remaining = patience
        epoch = 0
        best_val_loss = float("inf")

        scheduler = ReduceLROnPlateau(self.optimizer,
                                      factor=0.5,
                                      patience=5,
                                      threshold=1e-4,
                                      verbose=True)

        while epoch < epochs and remaining > 0:

            self.model.train()

            running_loss = 0
            batch_num = 0
            self.loader.reset()

            # Iterate over all training batches of this epoch.
            while self.loader.has_next():
                batch_num += 1
                print("epoch {} batch {}".format(epoch, batch_num), end='\r')

                # Fetch the next batch of positive/negative training samples.
                pos, neg = self.loader.next_batch()

                # Do any normalization before each mini-batch.
                self.model.norm_step()

                # The forward pass returns the batch loss directly.
                batch_loss = self.model(pos[:, 0], pos[:, 1], pos[:, 2],
                                        pos[:, 3], neg[:, 0], neg[:, 1],
                                        neg[:, 2], neg[:, 3])
                running_loss += batch_loss

                # Backpropagate and update the weights, then clear the
                # gradients so the next batch starts fresh (gradients are
                # only accumulated within a batch).
                batch_loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

            # Total epoch loss over all training samples
            #   -- assumes the model's criterion uses reduction='sum'.
            running_loss = running_loss.item()

            # Validation loss at the end of each epoch.
            val_loss = self._val_loss()

            # Update the learning rate if needed.
            scheduler.step(val_loss)

            # Consume patience when the improvement is too small.
            if (best_val_loss - val_loss) < tolerance:
                remaining -= 1

            # Checkpoint on a new best validation loss and reset patience.
            if val_loss < best_val_loss:
                torch.save(
                    self.model.state_dict(),
                    os.path.join(self.config['name'],
                                 'best_val_loss_state_dict.pt'))
                best_val_loss = val_loss
                remaining = patience

            # Print feedback and log the epoch metrics.
            print(
                "epoch {} - loss: {:.8}, val_loss: {:.8}, patience: {}".format(
                    epoch, running_loss, val_loss, remaining))
            self.logger.log(epoch, {'loss': running_loss,
                                    'val_loss': val_loss})

            epoch += 1
Esempio n. 21
0
class Trainer(BaseTrainer):
    """
    Trainer class

    Note:
        Inherited from BaseTrainer.
        self.optimizer is by default handled by BaseTrainer based on config.
    """
    def __init__(self,
                 model,
                 loss,
                 metrics,
                 resume,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 train_logger=None,
                 writer=None):
        super(Trainer, self).__init__(model, loss, metrics, resume, config,
                                      train_logger)
        self.config = config
        self.batch_size = data_loader.batch_size
        self.data_loader = data_loader
        self.valid_data_loader = valid_data_loader
        # Validation is only run when a validation loader was supplied.
        self.valid = True if self.valid_data_loader is not None else False
        # Log images roughly every sqrt(batch_size) batches.
        self.log_step = int(np.sqrt(self.batch_size))
        self.writer = writer
        self.scheduler = ReduceLROnPlateau(self.optimizer,
                                           factor=0.5,
                                           patience=50,
                                           verbose=True)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current training epoch.
        :return: A log that contains all information you want to save.

        Note:
            If you have additional information to record, for example:
                > additional_log = {"x": x, "y": y}
            merge it with log before return. i.e.
                > log = {**log, **additional_log}
                > return log

            The metrics in log must have the key 'metrics'.
        """
        self.model.train()

        total_loss = 0
        total_metrics = np.zeros(len(self.metrics))
        for batch_idx, (data, target) in enumerate(self.data_loader):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data)

            loss = self.loss(output, target)
            loss.backward()
            self.optimizer.step()

            # Global step index for the writer (continuous across epochs).
            train_steps = epoch * len(self.data_loader) + batch_idx
            self.writer.add_scalar('train/loss', loss.item(), train_steps)
            acc_metrics = np.zeros(len(self.metrics))
            for i, metric in enumerate(self.metrics):
                acc_metrics[i] += metric(output, target)
                self.writer.add_scalar(f'train/{metric.__name__}',
                                       acc_metrics[i], train_steps)

            # Periodically log image grids and console progress.
            if self.verbosity >= 2 and batch_idx % self.log_step == 0:
                self.writer.add_image('train/input',
                                      make_grid(data[:32].cpu(), nrow=4),
                                      train_steps)
                self.writer.add_image('train/target',
                                      make_grid(target[:32].cpu(), nrow=4),
                                      train_steps)
                self.writer.add_image('train/output',
                                      make_grid(output[:32].cpu(), nrow=4),
                                      train_steps)
                self.logger.info(
                    'Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}'.format(
                        epoch, batch_idx * self.data_loader.batch_size,
                        self.data_loader.n_samples,
                        100.0 * batch_idx / len(self.data_loader),
                        loss.item()))
            total_metrics += acc_metrics
            total_loss += loss.item()

        # Per-batch averages over the epoch.
        log = {
            'loss': total_loss / len(self.data_loader),
            'metrics': (total_metrics / len(self.data_loader)).tolist()
        }

        if self.valid:
            val_log = self._valid_epoch(epoch)
            log = {**log, **val_log}

        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :return: A log that contains information about validation

        Note:
            The validation metrics in log must have the key 'val_metrics'.
        """
        self.model.eval()
        total_val_loss = 0
        total_val_metrics = np.zeros(len(self.metrics))
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.loss(output, target)

                valid_steps = epoch * len(self.valid_data_loader) + batch_idx
                self.writer.add_scalar('valid/loss', loss.item(), valid_steps)
                acc_metrics = np.zeros(len(self.metrics))
                for i, metric in enumerate(self.metrics):
                    acc_metrics[i] += metric(output, target)
                    self.writer.add_scalar(f'valid/{metric.__name__}',
                                           acc_metrics[i], valid_steps)

                # NOTE(review): unlike _train_epoch, these image grids are
                # written for *every* validation batch (no log_step guard) —
                # confirm this is intended, it can bloat the event files.
                self.writer.add_image('valid/input',
                                      make_grid(data[:32].cpu(), nrow=4),
                                      valid_steps)
                self.writer.add_image('valid/target',
                                      make_grid(target[:32].cpu(), nrow=4),
                                      valid_steps)
                self.writer.add_image('valid/output',
                                      make_grid(output[:32].cpu(), nrow=4),
                                      valid_steps)
                total_val_loss += loss.item()
                total_val_metrics += acc_metrics

                # NOTE(review): ReduceLROnPlateau is stepped once per
                # validation *batch* on the raw batch loss; it is normally
                # stepped once per epoch on the mean validation loss —
                # confirm whether per-batch stepping is intended.
                self.scheduler.step(loss.item())

        return {
            'val_loss':
            total_val_loss / len(self.valid_data_loader),
            'val_metrics':
            (total_val_metrics / len(self.valid_data_loader)).tolist()
        }
Esempio n. 22
0
def trainer(model,
            optimizer,
            train_loader,
            test_loader,
            epochs=5,
            gpus=1,
            tasks=1,
            classifacation=False,
            mae=False,
            pb=True,
            out="model.pt",
            cyclic=False,
            verbose=True):
    """Train ``model`` and checkpoint its state after every epoch.

    The loss is BCE-with-logits when ``classifacation`` (sic, kept for
    backward compatibility) is True, L1 when ``mae`` is True, MSE otherwise.
    The learning rate follows CosineAnnealingWarmRestarts when ``cyclic``,
    else ReduceLROnPlateau on the mean test loss; early stopping also watches
    the mean test loss.

    Returns:
        (model, tracker): the trained model and its metric/loss history.
    """
    device = next(model.parameters()).device
    if classifacation:
        tracker = trackers.ComplexPytorchHistory(
        ) if tasks > 1 else trackers.PytorchHistory(
            metric=metrics.roc_auc_score, metric_name='roc-auc')
    else:
        tracker = trackers.ComplexPytorchHistory(
        ) if tasks > 1 else trackers.PytorchHistory()

    earlystopping = EarlyStopping(patience=50, delta=1e-5)
    if cyclic:
        lr_red = CosineAnnealingWarmRestarts(optimizer, T_0=20)
    else:
        lr_red = ReduceLROnPlateau(optimizer,
                                   mode='min',
                                   factor=0.8,
                                   patience=20,
                                   cooldown=0,
                                   verbose=verbose,
                                   threshold=1e-4,
                                   min_lr=1e-8)

    for epochnum in range(epochs):
        train_loss = 0
        test_loss = 0
        train_iters = 0
        test_iters = 0
        model.train()
        if pb:
            gen = tqdm(enumerate(train_loader))
        else:
            gen = enumerate(train_loader)
        for i, (drugfeats, value) in gen:
            optimizer.zero_grad()
            drugfeats, value = drugfeats.to(device), value.to(device)
            pred, attn = model(drugfeats)

            if classifacation:
                mse_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                    pred, value).mean()
            elif mae:
                mse_loss = torch.nn.functional.l1_loss(pred, value).mean()
            else:
                mse_loss = torch.nn.functional.mse_loss(pred, value).mean()
            mse_loss.backward()
            # Clip gradients to stabilize training.
            torch.nn.utils.clip_grad_value_(model.parameters(), 10.0)
            optimizer.step()
            train_loss += mse_loss.item()
            train_iters += 1
            tracker.track_metric(pred=pred.detach().cpu().numpy(),
                                 value=value.detach().cpu().numpy())

        tracker.log_loss(train_loss / train_iters, train=True)
        tracker.log_metric(internal=True, train=True)

        model.eval()
        with torch.no_grad():
            for i, (drugfeats, value) in enumerate(test_loader):
                drugfeats, value = drugfeats.to(device), value.to(device)
                pred, attn = model(drugfeats)

                if classifacation:
                    mse_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                        pred, value).mean()
                elif mae:
                    mse_loss = torch.nn.functional.l1_loss(pred, value).mean()
                else:
                    mse_loss = torch.nn.functional.mse_loss(pred, value).mean()
                test_loss += mse_loss.item()
                test_iters += 1
                tracker.track_metric(pred.detach().cpu().numpy(),
                                     value.detach().cpu().numpy())
        # BUG FIX: the test-side loss was logged as train_loss / train_iters,
        # recording the training loss under the test history.
        tracker.log_loss(test_loss / test_iters, train=False)
        tracker.log_metric(internal=True, train=False)

        lr_red.step(test_loss / test_iters)
        earlystopping(test_loss / test_iters)
        if verbose:
            print("Epoch", epochnum, train_loss / train_iters,
                  test_loss / test_iters, tracker.metric_name,
                  tracker.get_last_metric(train=True),
                  tracker.get_last_metric(train=False))

        if out is not None:
            # DataParallel wraps the model in .module; unwrap when multi-GPU.
            if gpus == 1:
                state = model.state_dict()
                heads = model.nheads
            else:
                state = model.module.state_dict()
                heads = model.module.nheads
            torch.save(
                {
                    'model_state': state,
                    'opt_state': optimizer.state_dict(),
                    'history': tracker,
                    'nheads': heads,
                    'ntasks': tasks
                }, out)
        if earlystopping.early_stop:
            break
    return model, tracker
Esempio n. 23
0
def train(cfg,
          dataset_train=None,
          dataset_valid=None,
          dataset_test=None,
          recompile=True):
    """Train a model per ``cfg``, tracking the best epoch by validation AUC.

    The test AUC is only (re)computed when the validation AUC improves
    (early-stopping style model selection). The best and latest models are
    written to ``checkpoints/<experiment_name>/``.

    Returns:
        (best_valid_auc, best_test_auc, metrics, results_dict)
    """
    print("Our config:")
    pprint.pprint(cfg)

    # Get information from configuration.
    seed = cfg['seed']
    cuda = cfg['cuda']
    num_epochs = cfg['num_epochs']
    exp_name = cfg['experiment_name']
    recon_masked = cfg['recon_masked']
    recon_continuous = cfg['recon_continuous']

    device = 'cuda' if cuda else 'cpu'

    # Setting the seed (and deterministic cuDNN) for reproducibility.
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if cuda:
        torch.cuda.manual_seed_all(seed)

    # Dataset
    # transform
    tr_train = configuration.setup_transform(cfg, 'train')
    tr_valid = configuration.setup_transform(cfg, 'valid')
    tr_test = configuration.setup_transform(cfg, 'test')

    # The dataset (rebuilt unless pre-built datasets are passed in).
    if recompile:
        dataset_train = configuration.setup_dataset(cfg, 'train')(tr_train)
        dataset_valid = configuration.setup_dataset(cfg, 'valid')(tr_valid)
        dataset_test = configuration.setup_dataset(cfg, 'test')(tr_test)

    # Dataloader
    train_loader = torch.utils.data.DataLoader(dataset_train,
                                               batch_size=cfg['batch_size'],
                                               shuffle=cfg['shuffle'],
                                               num_workers=0,
                                               pin_memory=cuda)
    valid_loader = torch.utils.data.DataLoader(dataset_valid,
                                               batch_size=cfg['batch_size'],
                                               shuffle=cfg['shuffle'],
                                               num_workers=0,
                                               pin_memory=cuda)
    test_loader = torch.utils.data.DataLoader(dataset_test,
                                              batch_size=cfg['batch_size'],
                                              shuffle=cfg['shuffle'],
                                              num_workers=0,
                                              pin_memory=cuda)

    model = configuration.setup_model(cfg).to(device)
    print(model)
    # TODO: checkpointing

    # Optimizer
    optim = configuration.setup_optimizer(cfg)(model.parameters())
    # NOTE(review): this scheduler is created but never stepped (the
    # scheduler.step(auc_valid) call below is commented out) — confirm
    # whether LR scheduling is intended.
    scheduler = ReduceLROnPlateau(optim, mode='max')
    print(optim)

    criterion = torch.nn.CrossEntropyLoss()

    # Stats for the table.
    best_epoch, best_train_auc, best_valid_auc, best_test_auc = -1, -1, -1, -1
    metrics = []
    auc_valid = 0
    # BUG FIX: initialize auc_test and best_model so the per-epoch stat dict
    # and the final rendering/saving cannot hit an UnboundLocalError when the
    # validation AUC never beats the sentinel (e.g. NaN AUCs or num_epochs=0).
    auc_test = -1
    best_model = model

    # Wrap the function for mlflow (optional).
    valid_wrap_epoch = mlflow_logger.log_metric('valid_acc')(test_epoch)
    test_wrap_epoch = mlflow_logger.log_metric('test_acc')(test_epoch)

    img_viz_train = dataset_train[VIZ_IDX]
    img_viz_valid = dataset_valid[VIZ_IDX]

    print("CUDA: ", cuda)
    for epoch in range(num_epochs):

        # scheduler.step(auc_valid)
        avg_loss = train_epoch(epoch=epoch,
                               model=model,
                               device=device,
                               optimizer=optim,
                               train_loader=train_loader,
                               criterion=criterion,
                               bre_lambda=cfg['bre_lambda'],
                               recon_lambda=cfg['recon_lambda'],
                               actdiff_lambda=cfg['actdiff_lambda'],
                               gradmask_lambda=cfg['gradmask_lambda'],
                               recon_masked=recon_masked,
                               recon_continuous=recon_continuous)

        auc_train = valid_wrap_epoch(name="train",
                                     epoch=epoch,
                                     model=model,
                                     device=device,
                                     data_loader=train_loader,
                                     criterion=criterion)

        auc_valid = valid_wrap_epoch(name="valid",
                                     epoch=epoch,
                                     model=model,
                                     device=device,
                                     data_loader=valid_loader,
                                     criterion=criterion)

        # Early Stopping: compute best test_auc when we beat best valid score.
        if auc_valid > best_valid_auc:

            auc_test = test_wrap_epoch(name="test",
                                       epoch=epoch,
                                       model=model,
                                       device=device,
                                       data_loader=test_loader,
                                       criterion=criterion)

            best_train_auc = auc_train
            best_valid_auc = auc_valid
            best_test_auc = auc_test
            best_epoch = epoch
            best_model = copy.deepcopy(model)

        # Update the stat dictionary with each epoch, append to metrics list.
        stat = {
            "epoch": epoch,
            "train_loss": avg_loss,
            "valid_auc": auc_valid,
            "train_auc": auc_train,
            "test_auc": auc_test,
            "best_train_auc": best_train_auc,
            "best_valid_auc": best_valid_auc,
            "best_test_auc": best_test_auc,
            "best_epoch": best_epoch
        }
        stat.update(configuration.process_config(cfg))
        metrics.append(stat)

    monitoring.log_experiment_csv(cfg, [best_valid_auc])

    results_dict = {
        'dataset_train': dataset_train,
        'dataset_valid': dataset_valid,
        'dataset_test': dataset_test
    }

    # Render gradients from the best model.
    render_img("best_train", best_epoch, img_viz_train, best_model, exp_name,
               cuda)
    render_img("best_valid", best_epoch, img_viz_valid, best_model, exp_name,
               cuda)

    # Write best model to disk.
    output_dir = os.path.join('checkpoints', exp_name)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Saves maxmasks and seed in the name.
    output_name = "best_model_{}_{}.pth.tar".format(cfg['seed'],
                                                    cfg['maxmasks_train'])
    torch.save(best_model, os.path.join(output_dir, output_name))

    # Save latest model as well.
    output_name = "latest_model_{}_{}.pth.tar".format(cfg['seed'],
                                                      cfg['maxmasks_train'])
    torch.save(model, os.path.join(output_dir, output_name))

    return (best_valid_auc, best_test_auc, metrics, results_dict)
Esempio n. 24
0
def main(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    train_path = get_dset_path(args.dataset_name, 'train')
    val_path = get_dset_path(args.dataset_name, 'val')

    long_dtype, float_dtype = get_dtypes(args)

    logger.info("Initializing train dataset")
    train_dset, train_loader = data_loader(args, train_path)
    logger.info("Initializing val dataset")
    _, val_loader = data_loader(args, val_path)

    logger.info('The length of training data is {}'.format(len(train_dset)))
    iterations_per_epoch = len(train_dset) / args.batch_size / args.d_steps

    if args.num_epochs:
        args.num_iterations = int(iterations_per_epoch * args.num_epochs)

    logger.info(
        'There are {} iterations per epoch'.format(iterations_per_epoch)
    )

    generator = TrajectoryGenerator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        encoder_h_dim=args.encoder_h_dim_g,
        decoder_h_dim=args.decoder_h_dim_g,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        noise_dim=args.noise_dim,
        noise_type=args.noise_type,
        noise_mix_type=args.noise_mix_type,
        pooling_type=args.pooling_type,
        pool_every_timestep=args.pool_every_timestep,
        dropout=args.dropout,
        bottleneck_dim=args.bottleneck_dim,
        neighborhood_size=args.neighborhood_size,
        grid_size=args.grid_size,
        batch_norm=args.batch_norm)

    generator.apply(init_weights)
    generator.type(float_dtype).train()
    logger.info('Here is the generator:')
    logger.info(generator)

    discriminator = TrajectoryDiscriminator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        h_dim=args.encoder_h_dim_d,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        dropout=args.dropout,
        batch_norm=args.batch_norm,
        d_type=args.d_type)

    discriminator.apply(init_weights)
    discriminator.type(float_dtype).train()
    logger.info('Here is the discriminator:')
    logger.info(discriminator)

    g_loss_fn = gan_g_loss
    d_loss_fn = gan_d_loss

    optimizer_g = optim.Adam(generator.parameters(), lr=args.g_learning_rate)
    optimizer_d = optim.Adam(
        discriminator.parameters(), lr=args.d_learning_rate
    )
    scheduler_g = ReduceLROnPlateau(optimizer_g, mode='min', factor=0.2, patience=10, verbose=True)
    scheduler_d = ReduceLROnPlateau(optimizer_d, mode='min', factor=0.2, patience=3, verbose=True)
    # Maybe restore from checkpoint
    restore_path = None
    if args.checkpoint_start_from is not None:
        restore_path = args.checkpoint_start_from
    elif args.restore_from_checkpoint == 1:
        restore_path = os.path.join(args.output_dir,
                                    '%s_with_model.pt' % args.checkpoint_name)

    if restore_path is not None and os.path.isfile(restore_path):
        logger.info('Restoring from checkpoint {}'.format(restore_path))
        checkpoint = torch.load(restore_path)
        generator.load_state_dict(checkpoint['g_state'])
        discriminator.load_state_dict(checkpoint['d_state'])
        optimizer_g.load_state_dict(checkpoint['g_optim_state'])
        optimizer_d.load_state_dict(checkpoint['d_optim_state'])
        t = checkpoint['counters']['t']
        epoch = checkpoint['counters']['epoch']
        checkpoint['restore_ts'].append(t)
    else:
        # Starting from scratch, so initialize checkpoint data structure
        t, epoch = 0, 0
        checkpoint = {
            'args': args.__dict__,
            'G_losses': defaultdict(list),
            'D_losses': defaultdict(list),
            'losses_ts': [],
            'metrics_val': defaultdict(list),
            'metrics_train': defaultdict(list),
            'sample_ts': [],
            'restore_ts': [],
            'norm_g': [],
            'norm_d': [],
            'genoutputs': [],
            'gt': [],
            'w_loss_g': [],
            'w_loss_d':[],
            'counters': {
                't': None,
                'epoch': None,
            },
            'g_state': None,
            'g_optim_state': None,
            'd_state': None,
            'd_optim_state': None,
            'g_best_state': None,
            'd_best_state': None,
            'best_t': None,
            'g_best_nl_state': None,
            'd_best_state_nl': None,
            'best_t_nl': None,
        }
    t0 = None
    while t < args.num_iterations:
        gc.collect()
        d_steps_left = args.d_steps
        g_steps_left = args.g_steps
        epoch += 1
        logger.info('Starting epoch {}'.format(epoch))
        for batch in train_loader:
            if args.timing == 1:
                torch.cuda.synchronize()
                t1 = time.time()

            # Decide whether to use the batch for stepping on discriminator or
            # generator; an iteration consists of args.d_steps steps on the
            # discriminator followed by args.g_steps steps on the generator.
            if d_steps_left > 0:
                step_type = 'd'
                losses_d,w_loss_d,fake_trajs,real_trajs = discriminator_step(args, batch, generator,
                                              discriminator, d_loss_fn,
                                                                    optimizer_d,scheduler_d)
                checkpoint['norm_d'].append(
                    get_total_norm(discriminator.parameters()))
                d_steps_left -= 1
            elif g_steps_left > 0:
                step_type = 'g'
                losses_g, w_loss_g = generator_step(args, batch, generator,
                                          discriminator, g_loss_fn,
                                          optimizer_g,scheduler_g)
                checkpoint['norm_g'].append(
                    get_total_norm(generator.parameters())
                )
                g_steps_left -= 1

            if args.timing == 1:
                torch.cuda.synchronize()
                t2 = time.time()
                logger.info('{} step took {}'.format(step_type, t2 - t1))

            # Skip the rest if we are not at the end of an iteration
            if d_steps_left > 0 or g_steps_left > 0:
                continue

            if args.timing == 1:
                if t0 is not None:
                    logger.info('Interation {} took {}'.format(
                        t - 1, time.time() - t0
                    ))
                t0 = time.time()

            # Maybe save loss
            if t % args.print_every == 0:
                logger.info('w_loss_d: {}'.format(w_loss_d))
                logger.info('w_loss_g: {}'.format(w_loss_g))
                logger.info('t = {} / {}'.format(t + 1, args.num_iterations))
                for k, v in sorted(losses_d.items()):
                    logger.info('  [D] {}: {:.3f}'.format(k, v))
                    checkpoint['D_losses'][k].append(v)
                for k, v in sorted(losses_g.items()):
                    logger.info('  [G] {}: {:.3f}'.format(k, v))
                    checkpoint['G_losses'][k].append(v)
                checkpoint['losses_ts'].append(t)

            # Maybe save a checkpoint
            if t > 0 and t % args.checkpoint_every == 0:
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint['sample_ts'].append(t)
                checkpoint['genoutputs'].append(fake_trajs)
                checkpoint['gt'].append(real_trajs)
                checkpoint['w_loss_g'].append(w_loss_g)
                checkpoint['w_loss_d'].append(w_loss_d)

                # Check stats on the validation set
                logger.info('Checking stats on val ...')
                metrics_val = check_accuracy(
                    args, val_loader, generator, discriminator, d_loss_fn
                )
                logger.info('Checking stats on train ...')
                metrics_train = check_accuracy(
                    args, train_loader, generator, discriminator,
                    d_loss_fn, limit=True
                )
                logger.info('current generator learning rate: {}'.format(optimizer_g.state_dict()['param_groups'][0]['lr']))
                logger.info('current discriminator learning rate: {}'.format(optimizer_d.state_dict()['param_groups'][0]['lr']))

                for k, v in sorted(metrics_val.items()):
                    logger.info('  [val] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_val'][k].append(v)
                for k, v in sorted(metrics_train.items()):
                    logger.info('  [train] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_train'][k].append(v)

                min_ade = min(checkpoint['metrics_val']['ade'])
                min_ade_nl = min(checkpoint['metrics_val']['ade_nl'])

                if metrics_val['ade'] == min_ade:
                    logger.info('New low for avg_disp_error')
                    checkpoint['best_t'] = t
                    checkpoint['g_best_state'] = generator.state_dict()
                    checkpoint['d_best_state'] = discriminator.state_dict()

                if metrics_val['ade_nl'] == min_ade_nl:
                    logger.info('New low for avg_disp_error_nl')
                    checkpoint['best_t_nl'] = t
                    checkpoint['g_best_nl_state'] = generator.state_dict()
                    checkpoint['d_best_nl_state'] = discriminator.state_dict()

                # Save another checkpoint with model weights and
                # optimizer state
                checkpoint['g_state'] = generator.state_dict()
                checkpoint['g_optim_state'] = optimizer_g.state_dict()
                checkpoint['d_state'] = discriminator.state_dict()
                checkpoint['d_optim_state'] = optimizer_d.state_dict()
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_with_model.pt' % args.checkpoint_name
                )
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                torch.save(checkpoint, checkpoint_path)
                logger.info('Done.')

                # Save a checkpoint with no model weights by making a shallow
                # copy of the checkpoint excluding some items
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_no_model.pt' % args.checkpoint_name)
                logger.info('Saving checkpoint to {}'.format(checkpoint_path))
                key_blacklist = [
                    'g_state', 'd_state', 'g_best_state', 'g_best_nl_state',
                    'g_optim_state', 'd_optim_state', 'd_best_state',
                    'd_best_nl_state'
                ]
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
                logger.info('Done.')

            t += 1
            d_steps_left = args.d_steps
            g_steps_left = args.g_steps
            if t >= args.num_iterations:
                break
Esempio n. 25
0
def main(arguments):
    """Train a likelihood-to-evidence ratio estimator and persist results.

    `arguments` is an argparse-style namespace. Side effects: trains the
    estimator via the project's Trainer abstraction and, when
    `arguments.out` is set, writes the loss curves, best/final model
    weights and the training summary into that directory.
    """
    # Allocate the datasets.
    dataset_test = allocate_dataset_test(arguments)
    dataset_train = allocate_dataset_train(arguments)
    # Allocate the ratio estimator.
    estimator = allocate_estimator(arguments)
    # Clip gradients in-place through a backward hook, if requested.
    if arguments.clip_grad != 0.0:
        for p in estimator.parameters():
            p.register_hook(lambda grad: torch.clamp(
                grad, -arguments.clip_grad, arguments.clip_grad))
    # Allocate the optimizer.
    optimizer = torch.optim.AdamW(estimator.parameters(),
                                  amsgrad=arguments.amsgrad,
                                  lr=arguments.lr,
                                  weight_decay=arguments.weight_decay)
    # Prepare the training criterion.
    if arguments.conservativeness > 0.0:
        criterion = BaseConservativeCriterion(
            batch_size=arguments.batch_size,
            beta=arguments.conservativeness,
            denominator=arguments.denominator,
            estimator=estimator,
            logits=arguments.logits)
    else:
        criterion = BaseCriterion(batch_size=arguments.batch_size,
                                  denominator=arguments.denominator,
                                  estimator=estimator,
                                  logits=arguments.logits)
    # NOTE: the experimental criterion takes precedence and silently
    # discards the (possibly conservative) criterion chosen above.
    if arguments.experimental:
        criterion = BaseExperimentalCriterion(
            batch_size=arguments.batch_size,
            denominator=arguments.denominator,
            estimator=estimator,
            logits=arguments.logits)
    # Allocate the learning rate scheduler, if requested. Falls back to a
    # plateau-based schedule when no explicit step/gamma pair is given.
    if arguments.lrsched:
        if arguments.lrsched_every is None or arguments.lrsched_gamma is None:
            lr_scheduler = ReduceLROnPlateau(optimizer, verbose=True)
        else:
            lr_scheduler = StepLR(optimizer,
                                  step_size=arguments.lrsched_every,
                                  gamma=arguments.lrsched_gamma)
    else:
        lr_scheduler = None
    # Allocate the trainer
    Trainer = create_trainer(criterion, arguments.denominator)
    trainer = Trainer(accelerator=hypothesis.accelerator,
                      batch_size=arguments.batch_size,
                      criterion=criterion,
                      dataset_test=dataset_test,
                      dataset_train=dataset_train,
                      epochs=arguments.epochs,
                      estimator=estimator,
                      lr_scheduler=lr_scheduler,
                      shuffle=(not arguments.dont_shuffle),
                      optimizer=optimizer,
                      workers=arguments.workers)
    # Register the callbacks
    if arguments.show:
        # Progress bar reporting the latest test loss after every epoch.
        progress_bar = tqdm(total=arguments.epochs)

        def report_test_loss(caller):
            # `caller` is the trainer instance that emitted the event.
            test_loss = caller.losses_test[-1]
            progress_bar.set_description("Test loss %s" % test_loss)
            progress_bar.update(1)

        trainer.add_event_handler(trainer.events.epoch_complete,
                                  report_test_loss)
    # Run the optimization procedure
    summary = trainer.fit()
    if arguments.show:
        # Cleanup the progress bar
        progress_bar.close()
        print(summary)
    if arguments.out is None:
        return  # No output directory has been specified, exit.
    # Create the output directory (and any missing parents) if needed.
    os.makedirs(arguments.out, exist_ok=True)
    best_model_weights = summary.best_model()
    final_model_weights = summary.final_model()
    train_losses = summary.train_losses()
    test_losses = summary.test_losses()
    # Save the results.
    np.save(os.path.join(arguments.out, "losses-train.npy"), train_losses)
    np.save(os.path.join(arguments.out, "losses-test.npy"), test_losses)
    torch.save(best_model_weights, os.path.join(arguments.out, "best-model.th"))
    torch.save(final_model_weights, os.path.join(arguments.out, "model.th"))
    summary.save(os.path.join(arguments.out, "result.summary"))
Esempio n. 26
0
    test_loss /= 100

    print("Average Loss: ", test_loss.item())
    return test_loss


def test(x, model):
    model.eval()
    print("Prediction: ", model(torch.Tensor([x])))


# --- Script entry: train Model1 on random integer inputs ---------------------
loss_data = []  # loss history; presumably appended to by train()/run() -- TODO confirm
net = Model1()
# NOTE(review): lr=1 is unusually large for Adam -- confirm this is intended.
optimizer = Adam(net.parameters(), lr=1)
scheduler = ReduceLROnPlateau(optimizer, verbose=True)

#net.load_state_dict(torch.load("model1_model.txt"))
#test(torch.Tensor([5]), net)

# Train for 5000 epochs, each on a fresh batch of 50 random values in [0, 100).
for epoch in progressbar.progressbar(range(5000), redirect_stdout=True):
    x = torch.randint(100, (50, 1), dtype=torch.float)
    train(x, net, optimizer)
    test_loss = run(x, net)
    #scheduler.step(test_loss)

#print(net.state_dict())

# Plot the recorded loss history.
plt.plot(loss_data)
plt.show()
Esempio n. 27
0
def train_model_part(conf,
                     train_part="filterbank",
                     pretrained_filterbank=None):
    """Train one stage of the two-step source-separation system.

    Args:
        conf: configuration dict; keys "<part>_training" drive scheduling,
            early stopping and epoch count for the selected part.
        train_part: which sub-module to train ("filterbank" by default).
        pretrained_filterbank: optional pretrained filterbank passed through
            to the model factory when training the second stage.

    Side effects: dumps `conf` to <exp_dir>/conf.yml, trains with
    PyTorch-Lightning, and writes the best-checkpoint map to
    <checkpoint_dir>/best_k_models.json.

    NOTE(review): uses legacy Lightning Trainer arguments (`max_nb_epochs`,
    `checkpoint_callback`, `early_stop_callback`, ...) -- this is tied to an
    old pytorch-lightning release; confirm before upgrading.
    """
    train_loader, val_loader = get_data_loaders(conf, train_part=train_part)

    # Define model and optimizer in a local function (defined in the recipe).
    # Two advantages to this : re-instantiating the model and optimizer
    # for retraining and evaluating is straight-forward.
    model, optimizer = make_model_and_optimizer(
        conf,
        model_part=train_part,
        pretrained_filterbank=pretrained_filterbank)
    # Define scheduler: halve the LR on validation plateaus when the
    # "<x>_half_lr" flag is set for this part.
    scheduler = None
    if conf[train_part + "_training"][train_part[0] + "_half_lr"]:
        scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                      factor=0.5,
                                      patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir, checkpoint_dir = get_encoded_paths(conf, train_part)
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function: permutation-invariant negative SI-SDR computed
    # from the pairwise loss matrix.
    loss_func = PITLossWrapper(PairwiseNegSDR("sisdr", zero_mean=False),
                               pit_from="pw_mtx")
    system = SystemTwoStep(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf,
        module=train_part,
    )

    # Define callbacks: keep only the single best checkpoint by val_loss.
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor="val_loss",
                                 mode="min",
                                 save_top_k=1,
                                 verbose=1)
    early_stopping = False
    if conf[train_part + "_training"][train_part[0] + "_early_stop"]:
        early_stopping = EarlyStopping(monitor="val_loss",
                                       patience=30,
                                       verbose=1)
    # Don't ask GPU if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_nb_epochs=conf[train_part + "_training"][train_part[0] +
                                                     "_epochs"],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_save_path=exp_dir,
        gpus=gpus,
        distributed_backend="dp",
        train_percent_check=1.0,  # Useful for fast experiment
        gradient_clip_val=5.0,
    )
    trainer.fit(system)

    # Persist the mapping checkpoint-path -> monitored value for later reuse.
    with open(os.path.join(checkpoint_dir, "best_k_models.json"), "w") as file:
        json.dump(checkpoint.best_k_models, file, indent=0)
Esempio n. 28
0
def trainer(model, optimizer, train_loader, test_loader, mode, epochs=5):
    """Train `model` for `epochs` epochs and evaluate after each one.

    Args:
        model: network taking (rnaseq, drugfeats) or (rnaseq, graph, feats)
            depending on `mode`.
        optimizer: optimizer over `model`'s parameters.
        train_loader / test_loader: iterables yielding
            (rnaseq, drugfeats, value) batches.
        mode: input modality; 'desc'/'image'/'smiles' feed tensors directly,
            anything else treats `drugfeats` as a DGL-style graph with
            `ndata['atom_features']`.
        epochs: number of training epochs.

    Relies on module globals `args`, `device`, `trackers`, `tqdm` and
    (when `args.amp` is set) `amp`. Saves a checkpoint to `args.o` at the
    end and returns `(model, tracker)`.
    """
    tracker = trackers.PytorchHistory()
    lr_red = ReduceLROnPlateau(optimizer,
                               mode='min',
                               factor=0.5,
                               patience=30,
                               cooldown=0,
                               verbose=True,
                               threshold=1e-4,
                               min_lr=1e-8)

    def _forward(rnaseq, drugfeats, value):
        # Route the batch through the model according to the input modality;
        # returns (prediction, device-resident target).
        if mode == 'desc' or mode == 'image' or mode == 'smiles':
            rnaseq, drugfeats, value = rnaseq.to(device), drugfeats.to(
                device), value.to(device)
            return model(rnaseq, drugfeats), value
        # Graph modality: drugfeats is a graph carrying atom features.
        rnaseq, value = rnaseq.to(device), value.to(device)
        g = drugfeats
        h = g.ndata['atom_features'].to(device)
        return model(rnaseq, g, h), value

    for epochnum in range(epochs):
        train_loss = 0
        test_loss = 0
        train_iters = 0
        test_iters = 0

        # ---- Training pass ----
        model.train()
        if args.pb:
            gen = tqdm(enumerate(train_loader))
        else:
            gen = enumerate(train_loader)
        for i, (rnaseq, drugfeats, value) in gen:
            optimizer.zero_grad()
            pred, value = _forward(rnaseq, drugfeats, value)
            mse_loss = torch.nn.functional.mse_loss(pred, value).mean()

            if args.amp:
                # Mixed-precision backward via apex's loss scaling.
                with amp.scale_loss(mse_loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                mse_loss.backward()
            optimizer.step()
            train_loss += mse_loss.item()
            train_iters += 1
            tracker.track_metric(pred.detach().cpu().numpy(),
                                 value.detach().cpu().numpy())

        tracker.log_loss(train_loss / train_iters, train=True)
        tracker.log_metric(internal=True, train=True)

        # ---- Evaluation pass ----
        model.eval()
        with torch.no_grad():
            for i, (rnaseq, drugfeats, value) in enumerate(test_loader):
                pred, value = _forward(rnaseq, drugfeats, value)
                mse_loss = torch.nn.functional.mse_loss(pred, value).mean()
                test_loss += mse_loss.item()
                test_iters += 1
                tracker.track_metric(pred.detach().cpu().numpy(),
                                     value.detach().cpu().numpy())
        # BUG FIX: previously logged the *train* loss under train=False,
        # so the recorded eval-loss history was wrong.
        tracker.log_loss(test_loss / test_iters, train=False)
        tracker.log_metric(internal=True, train=False)

        lr_red.step(test_loss / test_iters)
        print("Epoch", epochnum,
              train_loss / train_iters, test_loss / test_iters, 'r2',
              tracker.get_last_metric(train=True),
              tracker.get_last_metric(train=False))

    # Persist model/optimizer state; unwrap DataParallel when multi-GPU.
    if args.g == 1:
        torch.save(
            {
                'model_state': model.state_dict(),
                'opt_state': optimizer.state_dict(),
                'inference_model': model,
                'history': tracker
            }, args.o)
    else:
        torch.save(
            {
                'model_state': model.module.state_dict(),
                'opt_state': optimizer.state_dict(),
                'inference_model': model.module,
                'history': tracker
            }, args.o)
    return model, tracker
Esempio n. 29
0
        lin3a = F.dropout(lin3, p=self.p, training=self.training)
        out = self.linear3a(lin3a)

        return out


# --- Script setup: Regression model, optimizer, scheduler, result buffers ---
plot_freq = 100  # plotting interval in steps -- usage not visible here; TODO confirm
# NOTE(review): dataset-specific pixel totals; verify against the data pipeline.
n_out_pixels_train = 3244800
n_out_pixels_test = 6337500
model = Regression(D_in=14400, D_out=10).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-3)
# Drop the LR by 10x after 10 epochs without improvement (min mode).
scheduler = ReduceLROnPlateau(optimizer,
                              mode='min',
                              factor=0.1,
                              patience=10,
                              verbose=True,
                              threshold=0.0001,
                              threshold_mode='rel',
                              cooldown=0,
                              min_lr=0,
                              eps=1e-8)
loss_c = nn.CrossEntropyLoss()
# Train the model and init variables for results
total_step = len(train_loader)
loss_list = []
acc_list = []
acc_test_list = []
loss_test_list = []
train_acc_list = []
test_acc_list = []

epochs = 200  # Number of training epochs.
class Model(object):

    def __init__(
            self,
            dataset,
            n_classes,
            max_n_objects,
            use_instance_segmentation=False,
            use_coords=False,
            load_model_path='',
            usegpu=True):

        self.dataset = dataset
        self.n_classes = n_classes
        self.max_n_objects = max_n_objects
        self.use_instance_segmentation = use_instance_segmentation
        self.use_coords = use_coords
        self.load_model_path = load_model_path
        self.usegpu = usegpu

        assert self.dataset in ['CVPPP', 'cityscapes']

        if self.dataset == 'CVPPP':
            self.model = CVPPPArchitecture(
                self.n_classes,
                self.use_instance_segmentation,
                self.use_coords,
                usegpu=self.usegpu)
        elif self.dataset == 'cityscapes':
            self.model = CityscapesArchitecture(
                self.n_classes,
                self.use_instance_segmentation,
                self.use_coords,
                usegpu=self.usegpu)

        self.__load_weights()

        if self.usegpu:
            cudnn.benchmark = True
            self.model.cuda()
            # self.model = torch.nn.DataParallel(self.model,
            #                                    device_ids=range(self.ngpus))

        print self.model

        self.vis = visdom.Visdom()
        self.training_metric_vis, self.test_metric_vis = None, None
        if self.use_instance_segmentation:
            self.instance_seg_vis = None

    def __load_weights(self):
        """Overlay pretrained weights from `self.load_model_path`, if set.

        The checkpoint is merged into the current state dict, so keys absent
        from the checkpoint keep their freshly initialized values. Loads onto
        CPU when `usegpu` is False.
        """

        if self.load_model_path != '':
            # NOTE: the assert message is a backslash-continued string, so it
            # deliberately contains the embedded line break / indentation.
            assert os.path.isfile(self.load_model_path), 'Model : {} does not \
                exists!'.format(self.load_model_path)
            print 'Loading model from {}'.format(self.load_model_path)

            model_state_dict = self.model.state_dict()

            if self.usegpu:
                pretrained_state_dict = torch.load(self.load_model_path)
            else:
                # Map all storages onto CPU when running without a GPU.
                pretrained_state_dict = torch.load(
                    self.load_model_path, map_location=lambda storage,
                    loc: storage)

            # Merge checkpoint weights over current ones, then reload.
            model_state_dict.update(pretrained_state_dict)
            self.model.load_state_dict(model_state_dict)

    def __define_variable(self, tensor, volatile=False):
        # Wrap a tensor for autograd. `volatile` is the legacy (PyTorch <0.4)
        # flag that disables graph construction for inference-only passes.
        return Variable(tensor, volatile=volatile)

    def __define_input_variables(
            self, features, fg_labels, ins_labels, n_objects, mode):
        """Wrap the minibatch tensors as Variables.

        Outside of training the Variables are marked volatile so no autograd
        graph is retained. Returns the four wrapped tensors in input order.
        """
        volatile = mode != 'training'
        wrap = lambda t: self.__define_variable(t, volatile=volatile)
        return (wrap(features), wrap(fg_labels),
                wrap(ins_labels), wrap(n_objects))

    def __define_criterion(self, class_weights, delta_var,
                           delta_dist, norm=2, optimize_bg=False,
                           criterion='CE'):
        """Build and cache the training losses as attributes on `self`.

        `criterion` selects the semantic-segmentation loss: 'CE', 'Dice',
        'Multi' (both) or None (neither). `class_weights` optionally
        re-weights classes for CE/Dice. `delta_var`, `delta_dist` and `norm`
        parameterize the discriminative (instance-embedding) loss, which is
        only created when `use_instance_segmentation` is enabled. An MSE
        loss for the object-count head is always created.
        """
        assert criterion in ['CE', 'Dice', 'Multi', None]

        smooth = 1.0  # smoothing constant forwarded to DiceLoss

        # Discriminative Loss (instance embeddings)
        if self.use_instance_segmentation:
            self.criterion_discriminative = DiscriminativeLoss(
                delta_var, delta_dist, norm, self.usegpu)
            if self.usegpu:
                self.criterion_discriminative = \
                    self.criterion_discriminative.cuda()

        # FG Segmentation Loss (class-weighted when weights are provided)
        if class_weights is not None:
            class_weights = self.__define_variable(
                torch.FloatTensor(class_weights))
            if criterion in ['CE', 'Multi']:
                self.criterion_ce = torch.nn.CrossEntropyLoss(class_weights)
            if criterion in ['Dice', 'Multi']:
                self.criterion_dice = DiceLoss(
                    optimize_bg=optimize_bg, weight=class_weights,
                    smooth=smooth)
        else:
            if criterion in ['CE', 'Multi']:
                self.criterion_ce = torch.nn.CrossEntropyLoss()
            if criterion in ['Dice', 'Multi']:
                self.criterion_dice = DiceLoss(
                    optimize_bg=optimize_bg, smooth=smooth)

        # MSE Loss (object-count regression head)
        self.criterion_mse = torch.nn.MSELoss()

        # Move the created criteria to GPU when applicable.
        if self.usegpu:
            if criterion in ['CE', 'Multi']:
                self.criterion_ce = self.criterion_ce.cuda()
            if criterion in ['Dice', 'Multi']:
                self.criterion_dice = self.criterion_dice.cuda()

            self.criterion_mse = self.criterion_mse.cuda()

    def __define_optimizer(self, learning_rate, weight_decay,
                           lr_drop_factor, lr_drop_patience, optimizer='Adam'):
        """Create `self.optimizer` over trainable parameters and a
        plateau-based LR scheduler (`self.lr_scheduler`)."""
        assert optimizer in ['RMSprop', 'Adam', 'Adadelta', 'SGD']

        # Only optimize parameters that still require gradients
        # (e.g. a frozen CNN backbone is excluded).
        trainable_params = ifilter(lambda p: p.requires_grad,
                                   self.model.parameters())

        if optimizer == 'SGD':
            # SGD is the only variant configured with a momentum term.
            self.optimizer = optim.SGD(
                trainable_params, lr=learning_rate, momentum=0.9,
                weight_decay=weight_decay)
        else:
            optimizer_classes = {
                'RMSprop': optim.RMSprop,
                'Adadelta': optim.Adadelta,
                'Adam': optim.Adam,
            }
            self.optimizer = optimizer_classes[optimizer](
                trainable_params, lr=learning_rate,
                weight_decay=weight_decay)

        # Scale the LR by `lr_drop_factor` when the monitored quantity
        # plateaus for `lr_drop_patience` epochs.
        self.lr_scheduler = ReduceLROnPlateau(
            self.optimizer, mode='min', factor=lr_drop_factor,
            patience=lr_drop_patience, verbose=True)

    @staticmethod
    def __get_loss_averager():
        # Fresh running-average accumulator (project helper) for loss logging.
        return averager()

    def __minibatch(self, train_test_iter, clip_grad_norm,
                    criterion_type, train_cnn=True, mode='training',
                    debug=False):
        """Run one minibatch through the model; return a dict of loss metrics.

        Pulls the next batch from `train_test_iter`, computes the combined
        cost (discriminative + CE/Dice + MSE, as configured) and, when
        `mode == 'training'`, backpropagates and takes one optimizer step
        with optional gradient-norm clipping. In 'test' mode with `debug`,
        occasionally visualizes instance embeddings via t-SNE in visdom.
        """
        assert mode in ['training',
                        'test'], 'Mode must be either "training" or "test"'

        # Toggle requires_grad / train-eval state to match the mode;
        # optionally freeze the CNN backbone while training the rest.
        if mode == 'training':
            for param in self.model.parameters():
                param.requires_grad = True
            if not train_cnn:
                for param in self.model.cnn.parameters():
                    param.requires_grad = False
            self.model.train()
        else:
            for param in self.model.parameters():
                param.requires_grad = False
            self.model.eval()

        # Fetch the next batch (Python 2 iterator protocol: .next()).
        cpu_images, cpu_sem_seg_annotations, \
            cpu_ins_seg_annotations, cpu_n_objects = train_test_iter.next()
        cpu_images = cpu_images.contiguous()
        cpu_sem_seg_annotations = cpu_sem_seg_annotations.contiguous()
        cpu_ins_seg_annotations = cpu_ins_seg_annotations.contiguous()
        cpu_n_objects = cpu_n_objects.contiguous()

        # Host-to-device copies. `async=True` is the legacy spelling of
        # `non_blocking=True` (and a syntax error on Python 3.7+).
        if self.usegpu:
            gpu_images = cpu_images.cuda(async=True)
            gpu_sem_seg_annotations = cpu_sem_seg_annotations.cuda(async=True)
            gpu_ins_seg_annotations = cpu_ins_seg_annotations.cuda(async=True)
            gpu_n_objects = cpu_n_objects.cuda(async=True)
        else:
            gpu_images = cpu_images
            gpu_sem_seg_annotations = cpu_sem_seg_annotations
            gpu_ins_seg_annotations = cpu_ins_seg_annotations
            gpu_n_objects = cpu_n_objects

        # Wrap as Variables (volatile outside training).
        gpu_images, gpu_sem_seg_annotations, \
            gpu_ins_seg_annotations, gpu_n_objects = \
            self.__define_input_variables(gpu_images,
                                          gpu_sem_seg_annotations,
                                          gpu_ins_seg_annotations,
                                          gpu_n_objects, mode)
        # Object counts are regressed as a fraction of the maximum count.
        gpu_n_objects_normalized = gpu_n_objects.float() / self.max_n_objects

        sem_seg_predictions, ins_seg_predictions, \
            n_objects_predictions = self.model(gpu_images)

        if mode == 'test':
            if debug:
                # With probability ~0.3, t-SNE-project the instance
                # embeddings of predicted-foreground pixels for a random
                # sample and scatter-plot them in visdom.
                _vis_prob = np.random.rand()
                if _vis_prob > 0.7:
                    if self.use_instance_segmentation:
                        sem_seg_preds = np.argmax(
                            sem_seg_predictions.data.cpu().numpy(), axis=1)
                        seg_preds = ins_seg_predictions.data.cpu().numpy()

                        _bs, _n_feats = seg_preds.shape[:2]

                        _sample_idx = np.random.randint(_bs)
                        _sem_seg_preds_sample = sem_seg_preds[_sample_idx]
                        _seg_preds_sample = seg_preds[_sample_idx]

                        # Gather per-pixel embedding vectors over the
                        # predicted foreground mask.
                        fg_ins_embeddings = np.stack(
                            [_seg_preds_sample[i][np.where(
                                _sem_seg_preds_sample == 1)]
                                for i in range(_n_feats)], axis=1)
                        _n_fg_samples = fg_ins_embeddings.shape[0]
                        if _n_fg_samples > 0:
                            # Subsample 400 embeddings (with replacement)
                            # before the t-SNE projection.
                            fg_ins_embeddings = \
                                fg_ins_embeddings[np.random.choice(
                                    range(_n_fg_samples), size=400)]

                            tsne = TSNE(n_components=2, random_state=0)
                            fg_ins_embeddings_vis = tsne.fit_transform(
                                fg_ins_embeddings)

                            # Update the existing scatter plot, or create
                            # it on the first call.
                            if self.instance_seg_vis:
                                self.vis.scatter(X=fg_ins_embeddings_vis,
                                                 win=self.instance_seg_vis,
                                                 opts={'title':
                                                       'Predicted Embeddings \
                                                       for Foreground \
                                                       Predictions',
                                                       'markersize': 2})
                            else:
                                self.instance_seg_vis =\
                                    self.vis.scatter(X=fg_ins_embeddings_vis,
                                                     opts={'title':
                                                           'Predicted \
                                                           Embeddings for \
                                                           Foreground \
                                                           Predictions',
                                                           'markersize': 2})

        # Accumulate all enabled loss terms into a single scalar cost.
        cost = 0.0
        out_metrics = dict()

        if self.use_instance_segmentation:
            disc_cost = self.criterion_discriminative(
                ins_seg_predictions, gpu_ins_seg_annotations.float(),
                cpu_n_objects, self.max_n_objects)
            cost += disc_cost
            out_metrics['Discriminative Cost'] = disc_cost.data

        if criterion_type in ['CE', 'Multi']:
            # CE expects class indices: argmax the one-hot annotations.
            _, gpu_sem_seg_annotations_criterion_ce = \
                gpu_sem_seg_annotations.max(1)
            ce_cost = self.criterion_ce(
                sem_seg_predictions.permute(0, 2, 3, 1).contiguous().view(
                    -1, self.n_classes),
                gpu_sem_seg_annotations_criterion_ce.view(-1))
            cost += ce_cost
            out_metrics['CE Cost'] = ce_cost.data
        if criterion_type in ['Dice', 'Multi']:
            dice_cost = self.criterion_dice(
                sem_seg_predictions, gpu_sem_seg_annotations)
            cost += dice_cost
            out_metrics['Dice Cost'] = dice_cost.data

        # Object-count regression against the normalized ground truth.
        mse_cost = self.criterion_mse(
            n_objects_predictions, gpu_n_objects_normalized)
        cost += mse_cost
        out_metrics['MSE Cost'] = mse_cost.data

        if mode == 'training':
            self.model.zero_grad()
            cost.backward()
            # Clip the global gradient norm when requested (0 disables).
            if clip_grad_norm != 0:
                torch.nn.utils.clip_grad_norm(
                    self.model.parameters(), clip_grad_norm)
            self.optimizer.step()

        return out_metrics

    def __test(self, test_loader, criterion_type, epoch, debug):
        """Evaluate over `test_loader`; print and plot the mean metrics.

        Runs `__minibatch` in test mode over the whole loader, averages each
        metric, prints a summary line, appends the point at `epoch` to the
        'Test Metrics' visdom line plot, and returns the dict of averaged
        metrics. (Python 2 codebase: `iteritems`, `print` statement.)
        """

        n_minibatches = len(test_loader)

        test_iter = iter(test_loader)

        # Collect per-minibatch metric values, keyed by metric name.
        out_metrics = dict()
        for minibatch_index in range(n_minibatches):
            mb_out_metrics = self.__minibatch(
                test_iter, 0.0, criterion_type, train_cnn=False, mode='test',
                debug=debug)
            for mk, mv in mb_out_metrics.iteritems():
                if mk not in out_metrics:
                    out_metrics[mk] = []
                out_metrics[mk].append(mv)

        # Reduce each metric to its mean and build the visdom payload.
        test_metric_vis_data, test_metric_vis_legend = [], []
        metrics_as_str = 'Testing:     [METRIC]'
        for mk, mv in out_metrics.iteritems():
            out_metrics[mk] = torch.stack(mv, dim=0).mean()
            metrics_as_str += ' {} : {} |'.format(mk, out_metrics[mk])

            test_metric_vis_data.append(out_metrics[mk])
            test_metric_vis_legend.append(mk)

        print metrics_as_str

        test_metric_vis_data = np.expand_dims(
            np.array(test_metric_vis_data), 0)

        # Append to the existing line plot, or create it on the first call.
        if self.test_metric_vis:
            self.vis.line(X=np.array([epoch]),
                          Y=test_metric_vis_data,
                          win=self.test_metric_vis,
                          update='append')
        else:
            self.test_metric_vis = self.vis.line(X=np.array([epoch]),
                                                 Y=test_metric_vis_data,
                                                 opts={'legend':
                                                       test_metric_vis_legend,
                                                       'title': 'Test Metrics',
                                                       'showlegend': True,
                                                       'xlabel': 'Epoch',
                                                       'ylabel': 'Metric'})

        return out_metrics

    def fit(self, criterion_type, delta_var, delta_dist, norm,
            learning_rate, weight_decay, clip_grad_norm,
            lr_drop_factor, lr_drop_patience, optimize_bg, optimizer,
            train_cnn, n_epochs, class_weights, train_loader, test_loader,
            model_save_path, debug):
        """Run the full training loop.

        Per epoch: push every training minibatch through __minibatch,
        average the returned metrics, plot them to visdom, validate via
        __test, step the LR scheduler on the validation cost, and save
        a checkpoint whenever the validation cost improves.  Train and
        validation costs are also appended per epoch to CSV log files
        under `model_save_path`.

        NOTE(review): Python 2 style code (print statements,
        dict.iteritems) -- keep as-is unless the whole file is ported.
        """

        assert criterion_type in ['CE', 'Dice', 'Multi']

        # One CSV row per epoch: "<epoch>,<cost>".
        training_log_file = open(os.path.join(
            model_save_path, 'training.log'), 'w')
        validation_log_file = open(os.path.join(
            model_save_path, 'validation.log'), 'w')

        training_log_file.write('Epoch,Cost\n')
        validation_log_file.write('Epoch,Cost\n')

        self.__define_criterion(class_weights, delta_var, delta_dist,
                                norm=norm, optimize_bg=optimize_bg,
                                criterion=criterion_type)
        self.__define_optimizer(learning_rate, weight_decay,
                                lr_drop_factor, lr_drop_patience,
                                optimizer=optimizer)

        # Baseline evaluation before any weight update (logged as epoch -1).
        self.__test(test_loader, criterion_type, -1.0, debug)

        best_val_cost = np.Inf
        for epoch in range(n_epochs):
            epoch_start = time.time()

            train_iter = iter(train_loader)
            n_minibatches = len(train_loader)

            # Metric name -> list of per-minibatch tensor values.
            train_out_metrics = dict()

            minibatch_index = 0
            while minibatch_index < n_minibatches:
                mb_out_metrics = self.__minibatch(train_iter, clip_grad_norm,
                                                  criterion_type,
                                                  train_cnn=train_cnn,
                                                  mode='training', debug=debug)
                for mk, mv in mb_out_metrics.iteritems():
                    if mk not in train_out_metrics:
                        train_out_metrics[mk] = []
                    train_out_metrics[mk].append(mv)

                minibatch_index += 1

            epoch_end = time.time()
            epoch_duration = epoch_end - epoch_start

            training_metric_vis_data, training_metric_vis_legend = [], []

            print 'Epoch : [{}/{}] - [{}]'.format(epoch,
                                                  n_epochs, epoch_duration)
            metrics_as_str = 'Training:    [METRIC]'
            for mk, mv in train_out_metrics.iteritems():
                # Collapse the per-minibatch values into a single epoch mean.
                train_out_metrics[mk] = torch.stack(mv, dim=0).mean()
                metrics_as_str += ' {} : {} |'.format(mk,
                                                      train_out_metrics[mk])

                training_metric_vis_data.append(train_out_metrics[mk])
                training_metric_vis_legend.append(mk)

            print metrics_as_str

            training_metric_vis_data = np.expand_dims(
                np.array(training_metric_vis_data), 0)

            # First epoch creates the visdom window; later epochs append.
            if self.training_metric_vis:
                self.vis.line(X=np.array([epoch]),
                              Y=training_metric_vis_data,
                              win=self.training_metric_vis, update='append')
            else:
                self.training_metric_vis = self.vis.line(
                    X=np.array([epoch]), Y=training_metric_vis_data,
                    opts={'legend': training_metric_vis_legend,
                          'title': 'Training Metrics',
                          'showlegend': True, 'xlabel': 'Epoch',
                          'ylabel': 'Metric'})

            val_out_metrics = self.__test(
                test_loader, criterion_type, epoch, debug)
            # Choose the cost that drives scheduling/checkpointing,
            # matching the active criterion.
            if self.use_instance_segmentation:
                val_cost = val_out_metrics['Discriminative Cost']
                train_cost = train_out_metrics['Discriminative Cost']
            elif criterion_type in ['Dice', 'Multi']:
                val_cost = val_out_metrics['Dice Cost']
                train_cost = train_out_metrics['Dice Cost']
            else:
                val_cost = val_out_metrics['CE Cost']
                train_cost = train_out_metrics['CE Cost']

            self.lr_scheduler.step(val_cost)

            is_best_model = val_cost <= best_val_cost

            if is_best_model:
                best_val_cost = val_cost
                torch.save(self.model.state_dict(), os.path.join(
                    model_save_path, 'model_{}_{}.pth'.format(epoch,
                                                              val_cost)))

            training_log_file.write('{},{}\n'.format(epoch, train_cost))
            validation_log_file.write('{},{}\n'.format(epoch, val_cost))
            training_log_file.flush()
            validation_log_file.flush()

        training_log_file.close()
        validation_log_file.close()

    def predict(self, images):
        """Run the model in inference mode on a batch of images.

        Args:
            images: 4-D tensor (batch, channels, height, width).

        Returns:
            Tuple of CPU tensors (sem_seg_predictions,
            ins_seg_predictions, n_objects_predictions): softmaxed
            semantic maps, instance-segmentation outputs, and integer
            object counts (the raw prediction is scaled by
            self.max_n_objects, then rounded).
        """

        assert len(images.size()) == 4  # b, c, h, w

        # Freeze parameters and disable training-only layers.
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()

        images = images.contiguous()
        if self.usegpu:
            # NOTE(review): `async=True` is legacy PyTorch / Python 2;
            # `async` is a reserved word on Python >= 3.7 (the modern
            # spelling is non_blocking=True).
            images = images.cuda(async=True)

        images = self.__define_variable(images, volatile=True)

        sem_seg_predictions, ins_seg_predictions, n_objects_predictions = \
            self.model(images)

        # Class scores -> per-pixel probabilities over channels.
        sem_seg_predictions = torch.nn.functional.softmax(
            sem_seg_predictions, dim=1)

        # The count head presumably predicts a fraction of
        # self.max_n_objects -- rescale and round to an integer count.
        n_objects_predictions = n_objects_predictions * self.max_n_objects
        n_objects_predictions = torch.round(n_objects_predictions).int()

        sem_seg_predictions = sem_seg_predictions.data.cpu()
        ins_seg_predictions = ins_seg_predictions.data.cpu()
        n_objects_predictions = n_objects_predictions.data.cpu()

        return sem_seg_predictions, ins_seg_predictions, n_objects_predictions
Esempio n. 31
0
class Trainer(object):
    '''Training/evaluation/decoding harness for a sequence-transduction
    model: owns the data, model, optimizer, LR scheduler, evaluator and
    the list of saved checkpoints.

    Fix over the previous revision: all ad-hoc `open()` calls passed to
    torch.load/torch.save are now context-managed, so file handles are
    no longer leaked.
    '''

    def __init__(self, logger):
        super().__init__()
        self.logger = logger
        self.data = None
        # Run on GPU when one is visible, otherwise CPU.
        self.device = torch.device("cuda" if torch.cuda.
                                   is_available() else "cpu")
        self.model = None
        self.optimizer = None
        self.min_lr = 0
        self.scheduler = None
        self.evaluator = None
        # Dev loss of the previous epoch; used by the early-stop test.
        self.last_devloss = float('inf')
        # (filepath, devloss, eval_results) tuples of saved checkpoints.
        self.models = list()

    def load_data(self, dataset, train, dev, test=None):
        """Build the dataloader for `dataset`; may only be called once."""
        assert self.data is None
        logger = self.logger
        # yapf: disable
        if dataset == Data.sigmorphon19task1:
            # Task 1 expects a pair of training files.
            assert isinstance(train, list) and len(train) == 2
            self.data = dataloader.TagSIGMORPHON2019Task1(train, dev, test)
        else:
            raise ValueError
        # yapf: enable
        logger.info('src vocab size %d', self.data.source_vocab_size)
        logger.info('trg vocab size %d', self.data.target_vocab_size)
        logger.info('src vocab %r', self.data.source[:500])
        logger.info('trg vocab %r', self.data.target[:500])

    def build_model(self, opt):
        """Construct a fresh model from CLI options `opt` and move it
        to self.device; may only be called once."""
        assert self.model is None
        params = dict()
        params['src_vocab_size'] = self.data.source_vocab_size
        params['trg_vocab_size'] = self.data.target_vocab_size
        params['embed_dim'] = opt.embed_dim
        params['dropout_p'] = opt.dropout
        params['src_hid_size'] = opt.src_hs
        params['trg_hid_size'] = opt.trg_hs
        params['src_nb_layers'] = opt.src_layer
        params['trg_nb_layers'] = opt.trg_layer
        params['nb_attr'] = self.data.nb_attr
        params['wid_siz'] = opt.wid_siz
        params['src_c2i'] = self.data.source_c2i
        params['trg_c2i'] = self.data.target_c2i
        params['attr_c2i'] = self.data.attr_c2i
        mono = True
        # yapf: disable
        # Map (architecture, monotonicity) to the concrete model class.
        model_classfactory = {
            (Arch.soft, not mono): model.TagTransducer,
            (Arch.hard, not mono): model.TagHardAttnTransducer,
            (Arch.hmm, mono): model.MonoTagHMMTransducer,
            (Arch.hmmfull, not mono): model.TagFullHMMTransducer,
            (Arch.hmmfull, mono): model.MonoTagFullHMMTransducer
        }
        # yapf: enable
        model_class = model_classfactory[(opt.arch, opt.mono)]
        self.model = model_class(**params)
        self.logger.info('number of attribute %d', self.model.nb_attr)
        self.logger.info('dec 1st rnn %r', self.model.dec_rnn.layers[0])
        self.logger.info('number of parameter %d',
                         self.model.count_nb_params())
        self.model = self.model.to(self.device)

    def load_model(self, model):
        """Deserialize a saved model file and return the epoch number
        parsed from the filename suffix after the final '_'."""
        assert self.model is None
        self.logger.info('load model in %s', model)
        # Context manager closes the checkpoint file deterministically
        # (previously the handle returned by open() was leaked).
        with open(model, mode='rb') as fp:
            self.model = torch.load(fp, map_location=self.device)
        self.model = self.model.to(self.device)
        epoch = int(model.split('_')[-1])
        return epoch

    def smart_load_model(self, model_prefix):
        """Scan `{model_prefix}.nll*` checkpoints, rebuild self.models
        sorted by epoch, and load the most recent checkpoint."""
        assert self.model is None
        models = []
        for model in glob.glob(f'{model_prefix}.nll*'):
            # Filename pattern: ...nll_<loss>.<desc>_<res>...epoch_<n>
            res = re.findall(r'\w*_\d+\.?\d*', model)
            loss_, evals_, epoch_ = res[0].split('_'), res[1:-1], res[-1].split('_')
            assert loss_[0] == 'nll' and epoch_[0] == 'epoch'
            loss, epoch = float(loss_[1]), int(epoch_[1])
            evals = []
            for ev in evals_:
                ev = ev.split('_')
                # NOTE(review): short and long description both receive
                # ev[0] -- confirm this duplication is intended.
                evals.append(util.Eval(ev[0], ev[0], float(ev[1])))
            models.append((epoch, (model, loss, evals)))
        self.models = [x[1] for x in sorted(models)]
        return self.load_model(self.models[-1][0])

    def setup_training(self, optimizer, lr, min_lr, momentum, cooldown):
        """Create the optimizer (SGD/Adadelta/Adam), a halve-on-plateau
        LR scheduler bounded below by `min_lr`, and the evaluator."""
        assert self.model is not None
        if optimizer == 'SGD':
            self.optimizer = torch.optim.SGD(
                self.model.parameters(), lr, momentum=momentum)
        elif optimizer == 'Adadelta':
            self.optimizer = torch.optim.Adadelta(self.model.parameters(), lr)
        elif optimizer == 'Adam':
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr)
        else:
            raise ValueError
        self.min_lr = min_lr
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            'min',
            patience=0,
            cooldown=cooldown,
            factor=0.5,
            min_lr=min_lr)
        self.setup_evalutator()

    def save_training(self, model_fp):
        """Write optimizer + scheduler state to `{model_fp}.progress`."""
        save_objs = (self.optimizer.state_dict(), self.scheduler.state_dict())
        # Context manager ensures the progress file is flushed and closed.
        with open(f'{model_fp}.progress', 'wb') as fp:
            torch.save(save_objs, fp)

    def load_training(self, model_fp):
        """Restore optimizer + scheduler state saved by save_training."""
        assert self.model is not None
        with open(f'{model_fp}.progress', 'rb') as fp:
            optimizer_state, scheduler_state = torch.load(fp)
        self.optimizer.load_state_dict(optimizer_state)
        self.scheduler.load_state_dict(scheduler_state)

    def setup_evalutator(self):
        """Instantiate the evaluator (method name kept, typo and all,
        for API compatibility with existing callers)."""
        self.evaluator = util.BasicEvaluator()

    def train(self, epoch_idx, batch_size, max_norm):
        """Run one training epoch; clips gradient norm when max_norm > 0."""
        logger, model, data = self.logger, self.model, self.data
        logger.info('At %d-th epoch with lr %f.', epoch_idx,
                    self.optimizer.param_groups[0]['lr'])
        model.train()
        nb_train_batch = ceil(data.nb_train / batch_size)
        for src, src_mask, trg, _ in tqdm(
                data.train_batch_sample(batch_size), total=nb_train_batch):
            out = model(src, src_mask, trg)
            # trg[1:] presumably skips the initial BOS position --
            # verify against model.loss.
            loss = model.loss(out, trg[1:])
            self.optimizer.zero_grad()
            loss.backward()
            if max_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            logger.debug('loss %f with total grad norm %f', loss,
                         util.grad_norm(model.parameters()))
            self.optimizer.step()

    def iterate_batch(self, mode, batch_size):
        """Return (batch sampler, number of batches) for 'dev'/'test'."""
        if mode == 'dev':
            return self.data.dev_batch_sample, ceil(
                self.data.nb_dev / batch_size)
        elif mode == 'test':
            return self.data.test_batch_sample, ceil(
                self.data.nb_test / batch_size)
        else:
            raise ValueError(f'wrong mode: {mode}')

    def calc_loss(self, mode, batch_size, epoch_idx=-1):
        """Average the per-batch model loss over the dev or test set."""
        self.model.eval()
        sampler, nb_batch = self.iterate_batch(mode, batch_size)
        loss, cnt = 0, 0
        for src, src_mask, trg, _ in tqdm(sampler(batch_size), total=nb_batch):
            out = self.model(src, src_mask, trg)
            loss += self.model.loss(out, trg[1:]).item()
            cnt += 1
        loss = loss / cnt
        self.logger.info(
            'Average %s loss value per instance is %f at the end of epoch %d',
            mode, loss, epoch_idx)
        return loss

    def iterate_instance(self, mode):
        """Return (instance sampler, number of instances) for 'dev'/'test'."""
        if mode == 'dev':
            return self.data.dev_sample, self.data.nb_dev
        elif mode == 'test':
            return self.data.test_sample, self.data.nb_test
        else:
            raise ValueError(f'wrong mode: {mode}')

    def evaluate(self, mode, epoch_idx=-1, decode_fn=decode_greedy):
        """Run the evaluator over every instance with `decode_fn`, log
        each result and return the list of results."""
        self.model.eval()
        sampler, nb_instance = self.iterate_instance(mode)
        results = self.evaluator.evaluate_all(sampler, nb_instance, self.model,
                                              decode_fn)
        for result in results:
            self.logger.info('%s %s is %f at the end of epoch %d', mode,
                             result.long_desc, result.res, epoch_idx)
        return results

    def decode(self, mode, write_fp, decode_fn=decode_greedy):
        """Decode the dev/test set, writing hypotheses to
        `{write_fp}.{mode}.guess` and references to `...gold`."""
        self.model.eval()
        cnt = 0
        sampler, nb_instance = self.iterate_instance(mode)
        with open(f'{write_fp}.{mode}.guess', 'w') as out_fp, \
             open(f'{write_fp}.{mode}.gold', 'w') as trg_fp:
            for src, trg in tqdm(sampler(), total=nb_instance):
                pred, _ = decode_fn(self.model, src)
                # trg[1:-1] presumably strips BOS/EOS symbols -- verify
                # against the dataloader's decode_target.
                trg = self.data.decode_target(trg)[1:-1]
                pred = self.data.decode_target(pred)
                out_fp.write(f'{"".join(pred)}\n')
                trg_fp.write(f'{"".join(trg)}\n')
                cnt += 1
        self.logger.info(f'finished decoding {cnt} {mode} instance')

    def update_lr_and_stop_early(self, epoch_idx, devloss, estop):
        """Step the plateau scheduler with `devloss`; return True when
        the improvement is below `estop` while the LR has already
        bottomed out at min_lr (early-stop condition)."""
        prev_lr = self.optimizer.param_groups[0]['lr']
        self.scheduler.step(devloss)
        curr_lr = self.optimizer.param_groups[0]['lr']

        stop_early = True
        if (self.last_devloss - devloss) < estop and \
            prev_lr == curr_lr == self.min_lr:
            self.logger.info(
                'Early stopping triggered with epoch %d (previous dev loss: %f, current: %f)',
                epoch_idx, self.last_devloss, devloss)
            stop_status = stop_early
        else:
            stop_status = not stop_early
        self.last_devloss = devloss
        return stop_status

    def save_model(self, epoch_idx, devloss, eval_res, model_fp):
        """Serialize the full model; the filename encodes dev loss,
        eval results and epoch so smart_load_model can parse them."""
        eval_tag = '.'.join(['{}_{}'.format(e.desc, e.res) for e in eval_res])
        fp = model_fp + '.nll_{:.4f}.{}.epoch_{}'.format(
            devloss, eval_tag, epoch_idx)
        # Context manager closes the checkpoint file (previously leaked).
        with open(fp, 'wb') as fh:
            torch.save(self.model, fh)
        self.models.append((fp, devloss, eval_res))

    def reload_and_test(self, model_fp, batch_size, best_acc):
        """Select the best saved checkpoint, reload it, decode the dev
        set and (when a test file exists) evaluate on the test set.
        Returns the set of checkpoint paths worth keeping."""
        best_fp, _, best_res = self.models[0]
        # NOTE(review): this rebinding shadows the `best_acc` argument
        # with models[0]'s eval results, making the flag effectively
        # always truthy below -- confirm this is intended.
        best_acc_fp, _, best_acc = self.models[0]
        best_devloss_fp, best_devloss, _ = self.models[0]
        for fp, devloss, res in self.models:
            # [acc, edit distance ]
            if res[0].res >= best_res[0].res and res[1].res <= best_res[1].res:
                best_fp, best_res = fp, res
            if res[0].res >= best_acc[0].res:
                best_acc_fp, best_acc = fp, res
            if devloss <= best_devloss:
                best_devloss_fp, best_devloss = fp, devloss
        self.model = None
        if best_acc:
            best_fp = best_acc_fp
        self.logger.info(f'loading {best_fp} for testing')
        self.load_model(best_fp)
        self.logger.info('decoding dev set')
        self.decode(DEV, f'{model_fp}.decode')
        if self.data.test_file is not None:
            self.calc_loss(TEST, batch_size)
            self.logger.info('decoding test set')
            self.decode(TEST, f'{model_fp}.decode')
            results = self.evaluate(TEST)
            results = ' '.join([f'{r.desc} {r.res}' for r in results])
            self.logger.info(f'TEST {model_fp.split("/")[-1]} {results}')
        return set([best_fp])

    def cleanup(self, saveall, save_fps, model_fp):
        """Remove checkpoints not listed in `save_fps` (unless
        `saveall` is set) plus the training-progress file."""
        if not saveall:
            for fp, _, _ in self.models:
                if fp in save_fps:
                    continue
                os.remove(fp)
        os.remove(f'{model_fp}.progress')
def train_model(train_data,
                validate_data,
                test_data,
                lr=0.1,
                learning_model='random_walk_distribution',
                block_selection='coverage',
                n_epochs=150,
                batch_size=100,
                optimizer='adam',
                omega=4,
                training_method='surrogate-decision-focused',
                max_norm=0.1,
                block_cut_size=0.5,
                T_size=10):
    """Train a GCNPredictionNet2 with (surrogate) decision-focused
    learning on a network-security game.

    Each dataset entry packs a graph and ground-truth attacker data;
    every epoch runs a training / validating / testing pass.  Epoch -1
    evaluates the optimal (ground-truth) solution as a reference point
    before any learning happens.

    Returns:
        (net2, training/validating/testing loss lists,
         training/validating/testing defender-utility lists,
         (forward_time, inference_time, qp_time, backward_time),
         index of the last executed epoch)

    Raises:
        ValueError: for an unrecognized `optimizer` string.
        TypeError: for an unrecognized mode or training method.
    """

    # `feature_size` is a module-level global (node feature dimension).
    net2 = GCNPredictionNet2(feature_size)
    net2.train()

    # Surrogate reparametrization T (edges x T_size) is learned; the
    # bias s stays fixed at zero.
    sample_graph = train_data[0][0]
    init_T, init_s = torch.rand(sample_graph.number_of_edges(),
                                T_size), torch.zeros(
                                    sample_graph.number_of_edges())
    T, s = torch.tensor(
        normalize_matrix_positive(init_T), requires_grad=True
    ), torch.tensor(
        init_s, requires_grad=False
    )  # bias term s can cause infeasibility. It is not yet known how to resolve it.
    # Identity "surrogate" used only for the epoch == -1 optimal baseline.
    full_T, full_s = torch.eye(sample_graph.number_of_edges(),
                               requires_grad=False), torch.zeros(
                                   sample_graph.number_of_edges(),
                                   requires_grad=False)
    T_lr = lr

    # ================ Optimizer ================
    if optimizer == 'adam':
        optimizer = optim.Adam(net2.parameters(), lr=lr)
        T_optimizer = optim.Adam([T, s], lr=T_lr)
    elif optimizer == 'sgd':
        optimizer = optim.SGD(net2.parameters(), lr=lr)
        T_optimizer = optim.SGD([T, s], lr=T_lr)
    elif optimizer == 'adamax':
        optimizer = optim.Adamax(net2.parameters(), lr=lr)
        T_optimizer = optim.Adamax([T, s], lr=T_lr)
    else:
        # Fail fast with a clear message instead of crashing later with
        # an AttributeError on the un-replaced string.
        raise ValueError('unknown optimizer: {}'.format(optimizer))

    scheduler = ReduceLROnPlateau(optimizer, 'min')
    T_scheduler = ReduceLROnPlateau(T_optimizer, 'min')

    training_loss_list, validating_loss_list, testing_loss_list = [], [], []
    training_defender_utility_list, validating_defender_utility_list, testing_defender_utility_list = [], [], []

    print("Training...")
    forward_time, inference_time, qp_time, backward_time = 0, 0, 0, 0

    # Epochs <= pretrain_epochs weight the two-stage (likelihood) loss
    # fully; afterwards the decision-focused weight ramps up.
    pretrain_epochs = 0
    decay_rate = 0.95
    for epoch in range(-1, n_epochs):
        epoch_forward_time, epoch_inference_time, epoch_qp_time, epoch_backward_time = 0, 0, 0, 0
        if epoch <= pretrain_epochs:
            ts_weight = 1
            df_weight = 0
        else:
            ts_weight = decay_rate**(epoch - pretrain_epochs)
            df_weight = 1 - ts_weight

        for mode in ["training", "validating", "testing"]:
            if mode == "training":
                dataset = train_data
                epoch_loss_list = training_loss_list
                epoch_def_list = training_defender_utility_list
                # Epochs -1 and 0 only evaluate, so keep eval mode.
                if epoch > 0:
                    net2.train()
                else:
                    net2.eval()
            elif mode == "validating":
                dataset = validate_data
                epoch_loss_list = validating_loss_list
                epoch_def_list = validating_defender_utility_list
                net2.eval()
            elif mode == "testing":
                dataset = test_data
                epoch_loss_list = testing_loss_list
                epoch_def_list = testing_defender_utility_list
                net2.eval()
            else:
                raise TypeError("Not valid mode: {}".format(mode))

            loss_list, def_obj_list = [], []
            for iter_n in tqdm.trange(len(dataset)):
                forward_time_start = time.time()
                G, Fv, coverage_prob, phi_true, path_list, cut, log_prob, unbiased_probs_true, previous_gradient = dataset[
                    iter_n]
                n, m = G.number_of_nodes(), G.number_of_edges()
                budget = G.graph['budget']

                # =============== Compute edge probabilities ===========
                Fv_torch = torch.as_tensor(Fv, dtype=torch.float)
                edge_index = torch.Tensor(list(
                    nx.DiGraph(G).edges())).long().t()
                phi_pred = net2(Fv_torch, edge_index).view(
                    -1
                ) if epoch >= 0 else phi_true  # when epoch < 0, testing the optimal loss and defender utility

                unbiased_probs_pred = phi2prob(
                    G, phi_pred) if epoch >= 0 else unbiased_probs_true
                biased_probs_pred = prob2unbiased(
                    G, -coverage_prob, unbiased_probs_pred,
                    omega=omega)  # feeding negative coverage to be biased
                single_forward_time = time.time() - forward_time_start

                # =================== Compute loss =====================
                # Average negative log-likelihood of the observed paths
                # under the predicted (biased) edge probabilities.
                log_prob_pred = torch.zeros(1)
                for path in path_list:
                    for e in path:
                        log_prob_pred -= torch.log(
                            biased_probs_pred[e[0]][e[1]])
                log_prob_pred /= len(path_list)
                loss = (log_prob_pred - log_prob)[0]

                # ============== COMPUTE DEFENDER UTILITY ==============
                single_data = dataset[iter_n]

                if epoch == -1:  # optimal solution
                    cut_size = m
                    def_obj, def_coverage, (
                        single_inference_time, single_qp_time
                    ) = getDefUtility(
                        single_data,
                        full_T,
                        full_s,
                        unbiased_probs_pred,
                        learning_model,
                        cut_size=cut_size,
                        omega=omega,
                        verbose=False,
                        training_mode=False,
                        training_method=training_method,
                        block_selection=block_selection)  # feed forward only

                elif mode == 'testing' or mode == "validating" or epoch <= 0:
                    cut_size = m
                    def_obj, def_coverage, (
                        single_inference_time, single_qp_time
                    ) = getDefUtility(
                        single_data,
                        T,
                        s,
                        unbiased_probs_pred,
                        learning_model,
                        cut_size=cut_size,
                        omega=omega,
                        verbose=False,
                        training_mode=False,
                        training_method=training_method,
                        block_selection=block_selection)  # feed forward only

                else:
                    if training_method == 'decision-focused' or training_method == 'surrogate-decision-focused':
                        cut_size = m
                    else:
                        raise TypeError('Not defined method')

                    def_obj, def_coverage, (
                        single_inference_time, single_qp_time) = getDefUtility(
                            single_data,
                            T,
                            s,
                            unbiased_probs_pred,
                            learning_model,
                            cut_size=cut_size,
                            omega=omega,
                            verbose=False,
                            training_mode=True,
                            training_method=training_method,
                            block_selection=block_selection
                        )  # most time-consuming part

                if epoch > 0 and mode == 'training':
                    epoch_forward_time += single_forward_time
                    epoch_inference_time += single_inference_time
                    epoch_qp_time += single_qp_time

                def_obj_list.append(def_obj.item())
                loss_list.append(loss.item())

                # Step the optimizers once every `batch_size` samples.
                if (iter_n % batch_size == (batch_size - 1)) and (
                        epoch > 0) and (mode == "training"):
                    backward_start_time = time.time()
                    optimizer.zero_grad()
                    T_optimizer.zero_grad()
                    try:
                        if training_method == "decision-focused" or training_method == "surrogate-decision-focused":
                            # NOTE(review): ts_weight/df_weight feed only
                            # the alternative (commented) objective below.
                            (-def_obj).backward()
                            # (-def_obj * df_weight + loss * ts_weight).backward()
                        else:
                            raise TypeError("Not Implemented Method")

                        # Element-wise gradient clipping.
                        for parameter in net2.parameters():
                            parameter.grad = torch.clamp(parameter.grad,
                                                         min=-max_norm,
                                                         max=max_norm)
                        T.grad = torch.clamp(T.grad,
                                             min=-max_norm,
                                             max=max_norm)
                        optimizer.step()
                        T_optimizer.step()
                    except Exception:
                        # A missing (None) gradient raises above; the old
                        # bare `except:` also swallowed KeyboardInterrupt.
                        print("no grad is backpropagated...")
                    epoch_backward_time += time.time() - backward_start_time

                # ============== normalize T matrix =================
                T.data = normalize_matrix_positive(T.data)

            # ========= scheduler using validation set ==========
            if (epoch > 0) and (mode == "validating"):
                if training_method == "decision-focused" or training_method == "surrogate-decision-focused":
                    # Schedulers minimize, so feed the negated utility.
                    scheduler.step(-np.mean(def_obj_list))
                    T_scheduler.step(-np.mean(def_obj_list))
                else:
                    raise TypeError("Not Implemented Method")

            # ======= Storing loss and defender utility =========
            epoch_loss_list.append(np.mean(loss_list))
            epoch_def_list.append(np.mean(def_obj_list))

            # ========== Print stuff after every epoch ==========
            np.random.shuffle(dataset)
            print("Mode: {}/ Epoch number: {}/ Loss: {}/ DefU: {}".format(
                mode, epoch, np.mean(loss_list), np.mean(def_obj_list)))

        print('Forward time for this epoch: {}'.format(epoch_forward_time))
        print('QP time for this epoch: {}'.format(epoch_qp_time))
        print('Backward time for this epoch: {}'.format(epoch_backward_time))
        if epoch >= 0:
            forward_time += epoch_forward_time
            inference_time += epoch_inference_time
            qp_time += epoch_qp_time
            backward_time += epoch_backward_time

        # ============= early stopping criteria =============
        # Stop when the validation defender utility failed to improve
        # over the previous kk-epoch window, or went NaN for kk epochs.
        kk = 5
        if epoch >= kk * 2 - 1:
            GE_counts = np.sum(
                np.array(validating_defender_utility_list[1:][-kk:]) <=
                np.array(validating_defender_utility_list[1:][-2 * kk:-kk]) +
                1e-4)
            print(
                'Generalization error increases counts: {}'.format(GE_counts))
            if GE_counts == kk or np.sum(
                    np.isnan(
                        np.array(validating_defender_utility_list[1:]
                                 [-kk:]))) == kk:
                break

    # NOTE(review): average graph sizes are computed but never reported
    # or returned -- kept for parity; consider logging or removing.
    average_nodes = np.mean([x[0].number_of_nodes() for x in train_data] +
                            [x[0].number_of_nodes() for x in validate_data] +
                            [x[0].number_of_nodes() for x in test_data])
    average_edges = np.mean([x[0].number_of_edges() for x in train_data] +
                            [x[0].number_of_edges() for x in validate_data] +
                            [x[0].number_of_edges() for x in test_data])
    print('Total forward time: {}'.format(forward_time))
    print('Total qp time: {}'.format(qp_time))
    print('Total backward time: {}'.format(backward_time))

    return net2, training_loss_list, validating_loss_list, testing_loss_list, training_defender_utility_list, validating_defender_utility_list, testing_defender_utility_list, (
        forward_time, inference_time, qp_time, backward_time), epoch
class Spatial_CNN():
    """Trainer for the spatial (RGB-frame) stream of a two-stream action
    recognition network.

    Wraps model construction, optional checkpoint resume, the per-epoch
    train/validate loop, and aggregation of frame-level predictions into
    video-level accuracy.
    """

    def __init__(self, nb_epochs, lr, batch_size, resume, start_epoch, evaluate, train_loader, test_loader, test_video):
        self.nb_epochs = nb_epochs
        self.lr = lr
        self.batch_size = batch_size
        self.resume = resume            # checkpoint path to resume from; falsy disables resume
        self.start_epoch = start_epoch
        self.evaluate = evaluate        # truthy -> run one validation pass only
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.best_prec1 = 0             # best video-level top-1 accuracy seen so far
        self.test_video = test_video    # dict: video name -> ground-truth label (1-based)

    def build_model(self):
        """Build the ResNet-101 backbone and set up loss, optimizer and LR scheduler."""
        print ('==> Build model and setup loss and optimizer')
        #build model
        self.model = resnet101(pretrained= True, channel=3).cuda()
        #Loss function and optimizer
        self.criterion = nn.CrossEntropyLoss().cuda()
        self.optimizer = torch.optim.SGD(self.model.parameters(), self.lr, momentum=0.9)
        self.scheduler = ReduceLROnPlateau(self.optimizer, 'min', patience=1,verbose=True)

    def resume_and_evaluate(self):
        """Optionally restore a checkpoint, then optionally run evaluation only."""
        if self.resume:
            if os.path.isfile(self.resume):
                print("==> loading checkpoint '{}'".format(self.resume))
                checkpoint = torch.load(self.resume)
                self.start_epoch = checkpoint['epoch']
                self.best_prec1 = checkpoint['best_prec1']
                self.model.load_state_dict(checkpoint['state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                print("==> loaded checkpoint '{}' (epoch {}) (best_prec1 {})"
                  .format(self.resume, checkpoint['epoch'], self.best_prec1))
            else:
                print("==> no checkpoint found at '{}'".format(self.resume))
        if self.evaluate:
            self.epoch = 0
            prec1, val_loss = self.validate_1epoch()
            return

    def run(self):
        """Full train/validate loop with best-model tracking and checkpointing."""
        self.build_model()
        self.resume_and_evaluate()
        cudnn.benchmark = True

        for self.epoch in range(self.start_epoch, self.nb_epochs):
            self.train_1epoch()
            prec1, val_loss = self.validate_1epoch()
            is_best = prec1 > self.best_prec1
            # lr_scheduler: reduce the LR when validation loss plateaus
            self.scheduler.step(val_loss)
            # save video-level predictions of the best model so far
            if is_best:
                self.best_prec1 = prec1
                # FIX: dropped the redundant f.close() that followed this
                # 'with' block — the context manager already closes the file.
                with open('record/spatial/spatial_video_preds.pickle','wb') as f:
                    pickle.dump(self.dic_video_level_preds,f)

            save_checkpoint({
                'epoch': self.epoch,
                'state_dict': self.model.state_dict(),
                'best_prec1': self.best_prec1,
                'optimizer' : self.optimizer.state_dict()
            },is_best,'record/spatial/checkpoint.pth.tar','record/spatial/model_best.pth.tar')

    def train_1epoch(self):
        """Train the model for a single epoch over self.train_loader."""
        print('==> Epoch:[{0}/{1}][training stage]'.format(self.epoch, self.nb_epochs))
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        #switch to train mode
        self.model.train()
        end = time.time()
        # mini-batch training
        progress = tqdm(self.train_loader)
        for i, (data_dict,label) in enumerate(progress):

            # measure data loading time
            data_time.update(time.time() - end)

            # FIX: 'async' became a reserved keyword in Python 3.7 (SyntaxError);
            # the PyTorch argument was renamed to 'non_blocking'.
            label = label.cuda(non_blocking=True)
            target_var = Variable(label).cuda()

            # sum model outputs over every image crop in the sample dict
            output = Variable(torch.zeros(len(data_dict['img1']),101).float()).cuda()
            # FIX: inner index renamed to j — it used to shadow the batch
            # index i from the enumerate above.
            for j in range(len(data_dict)):
                # NOTE(review): keys generated as 'img0'..'imgN-1', but the
                # batch size above is read from data_dict['img1'] — confirm
                # the dataset's key naming convention.
                key = 'img'+str(j)
                data = data_dict[key]
                input_var = Variable(data).cuda()
                output += self.model(input_var)

            loss = self.criterion(output, target_var)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, label, topk=(1, 5))
            # FIX: loss.data[0] was removed in PyTorch >= 0.5; use loss.item().
            losses.update(loss.item(), data.size(0))
            top1.update(prec1[0], data.size(0))
            top5.update(prec5[0], data.size(0))

            # compute gradient and do SGD step
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

        info = {'Epoch':[self.epoch],
                'Batch Time':[round(batch_time.avg,3)],
                'Data Time':[round(data_time.avg,3)],
                'Loss':[round(losses.avg,5)],
                'Prec@1':[round(top1.avg,4)],
                'Prec@5':[round(top5.avg,4)],
                'lr': self.optimizer.param_groups[0]['lr']
                }
        record_info(info, 'record/spatial/rgb_train.csv','train')

    def validate_1epoch(self):
        """Run validation, accumulating frame scores into per-video predictions.

        Returns:
            (video_top1, video_loss) — video-level top-1 accuracy and loss.
        """
        print('==> Epoch:[{0}/{1}][validation stage]'.format(self.epoch, self.nb_epochs))
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        # switch to evaluate mode
        self.model.eval()
        self.dic_video_level_preds={}
        end = time.time()
        progress = tqdm(self.test_loader)
        # FIX: Variable(..., volatile=True) was removed in PyTorch 0.4;
        # torch.no_grad() is the supported way to disable autograd here.
        with torch.no_grad():
            for i, (keys,data,label) in enumerate(progress):

                label = label.cuda(non_blocking=True)
                data_var = Variable(data).cuda(non_blocking=True)

                # compute output
                output = self.model(data_var)
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                #Calculate video level prediction
                preds = output.data.cpu().numpy()
                nb_data = preds.shape[0]
                for j in range(nb_data):
                    videoName = keys[j].split('/',1)[0]
                    if videoName not in self.dic_video_level_preds:
                        self.dic_video_level_preds[videoName] = preds[j,:]
                    else:
                        self.dic_video_level_preds[videoName] += preds[j,:]

        video_top1, video_top5, video_loss = self.frame2_video_level_accuracy()

        info = {'Epoch':[self.epoch],
                'Batch Time':[round(batch_time.avg,3)],
                'Loss':[round(video_loss,5)],
                'Prec@1':[round(video_top1,3)],
                'Prec@5':[round(video_top5,3)]}
        record_info(info, 'record/spatial/rgb_test.csv','test')
        return video_top1, video_loss

    def frame2_video_level_accuracy(self):
        """Compute top-1/top-5 accuracy and loss from accumulated video scores."""
        correct = 0
        video_level_preds = np.zeros((len(self.dic_video_level_preds),101))
        video_level_labels = np.zeros(len(self.dic_video_level_preds))
        ii=0
        for name in sorted(self.dic_video_level_preds.keys()):

            preds = self.dic_video_level_preds[name]
            label = int(self.test_video[name])-1  # stored labels are 1-based

            video_level_preds[ii,:] = preds
            video_level_labels[ii] = label
            ii+=1
            if np.argmax(preds) == (label):
                correct+=1

        #top1 top5
        video_level_labels = torch.from_numpy(video_level_labels).long()
        video_level_preds = torch.from_numpy(video_level_preds).float()

        top1,top5 = accuracy(video_level_preds, video_level_labels, topk=(1,5))
        loss = self.criterion(Variable(video_level_preds).cuda(), Variable(video_level_labels).cuda())

        top1 = float(top1.numpy())
        top5 = float(top5.numpy())

        return top1,top5,loss.data.cpu().numpy()
Esempio n. 34
0
                          shuffle=True,
                          num_workers=num_workers)
# Validation loader: fixed order (shuffle=False) so per-epoch metrics are
# comparable. NOTE(review): bs / num_workers / valid_dataset are defined
# earlier in the file — confirm their values there.
valid_loader = DataLoader(valid_dataset,
                          batch_size=bs,
                          shuffle=False,
                          num_workers=num_workers)

# Loader registry keyed by phase name (Catalyst-style convention).
loaders = {"train": train_loader, "valid": valid_loader}

num_epochs = 10                 # total training epochs
log_interval = 5                # batches between log lines
logdir = "./logs/segmentation"  # destination for logs/checkpoints

# model, criterion, optimizer
optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-4)
# Shrink LR to 15% of its value after 2 epochs without improvement.
scheduler = ReduceLROnPlateau(optimizer, factor=0.15, patience=2)
criterion1 = torch.nn.MSELoss()

# -- TRAINING --


def plot(data, gt, out_a, out_b, gs, gt_gs):
    """Visualization helper: rebuild displayable images from network
    inputs/outputs.

    NOTE(review): this body appears truncated — it builds `d` and `gt_img`
    but never plots, returns, or otherwise uses them; confirm against the
    original source before relying on it.
    """
    gs = gs[0]
    # tile the first ground-truth value into a 44x44 image for display
    gt_gs = np.array([[np.array(gt_gs[0]) for i in range(44)]
                      for j in range(44)])
    d = to_np_img(data[0])
    d = cv2.split(d)
    # keep only the first two channels — presumably a log-chroma pair; verify
    d = np.dstack((d[0], d[1]))
    d = transform_from_log(d, gs)
    gt_img = np.array([[np.array(gt[0].cpu()) for i in range(44)]
                       for j in range(44)])
def main(args):
    """Train, validate and test the head-pose/gaze model.

    Builds loaders and the model, optionally restores a checkpoint, runs
    the train/validation loop with plateau LR scheduling and per-epoch
    checkpointing, then evaluates the best checkpoint on the test set and
    pickles the raw results.

    Args:
        args: parsed command-line namespace (lr, epochs, resume path, ...).
    """
    args.cuda = args.use_cuda and torch.cuda.is_available()
    train_set, validate_set, test_set, train_loader, validate_loader, test_loader = get_data.get_data_headpose(
        args)
    model = models.Gaze(args)
    # TODO: try to use the step policy for Adam, also consider the step interval
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    # TODO: related to line 99 (e.g. max for auc, min for loss; try to check the definition of this method)
    scheduler = ReduceLROnPlateau(optimizer,
                                  factor=args.lr_decay,
                                  patience=1,
                                  verbose=True,
                                  mode='min')
    # TODO; double check this loss, also the output of the network
    criterion = torch.nn.MSELoss()

    #------------------------
    # use multi-gpu

    if args.cuda and torch.cuda.device_count() > 1:
        print("Now Using ", len(args.device_ids), " GPUs!")

        #model=model.to(device_ids[0])
        model = torch.nn.DataParallel(model,
                                      device_ids=args.device_ids,
                                      output_device=args.device_ids[0]).cuda()
        criterion = criterion.cuda()

    elif args.cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    # optionally restore the best or most recent checkpoint
    if args.load_best_checkpoint:
        loaded_checkpoint = utils.load_best_checkpoint(args,
                                                       model,
                                                       optimizer,
                                                       path=args.resume)

        if loaded_checkpoint:
            args, best_epoch_error, avg_epoch_error, model, optimizer = loaded_checkpoint

    if args.load_last_checkpoint:
        loaded_checkpoint = utils.load_last_checkpoint(
            args,
            model,
            optimizer,
            path=args.resume,
            version=args.model_load_version)

        if loaded_checkpoint:
            args, best_epoch_error, avg_epoch_error, model, optimizer = loaded_checkpoint

    #------------------------------------------------------------------------------
    # Train

    since = time.time()

    train_epoch_loss_all = []
    val_epoch_loss_all = []

    best_loss = np.inf
    avg_epoch_loss = np.inf

    for epoch in range(args.start_epoch, args.epochs):

        train_epoch_loss = train(train_loader, model, criterion, optimizer,
                                 epoch, args)
        train_epoch_loss_all.append(train_epoch_loss)
        #visdom_viz(vis, train_epoch_loss_all, win=0, ylabel='Training Epoch Loss', title=args.project_name, color='green')

        val_epoch_loss = validate(validate_loader, model, criterion, epoch,
                                  args)
        val_epoch_loss_all.append(val_epoch_loss)
        #visdom_viz(vis, val_epoch_loss_all, win=1, ylabel='Validation Epoch Loss', title=args.project_name,color='blue')

        print(
            'Epoch {}/{} Training Loss: {:.4f} Validation Loss: {:.4f}'.format(
                epoch, args.epochs - 1, train_epoch_loss, val_epoch_loss))
        print('*' * 15)

        #TODO: reducing lr when there is no gains on validation metric results (e.g. auc, loss)
        scheduler.step(val_epoch_loss)

        is_best = val_epoch_loss < best_loss

        if is_best:
            best_loss = val_epoch_loss

        avg_epoch_loss = np.mean(val_epoch_loss_all)

        # checkpoint every epoch; is_best additionally updates the best model
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_epoch_error': best_loss,
                'avg_epoch_error': avg_epoch_loss,
                'optimizer': optimizer.state_dict(),
                'args': args
            },
            is_best=is_best,
            directory=args.resume,
            version='epoch_{}'.format(str(epoch)))

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best Val Loss: {},  Final Avg Val Loss: {}'.format(
        best_loss, avg_epoch_loss))

    #-------------------------------------------------------------------------------------------------------------
    # test
    loaded_checkpoint = utils.load_best_checkpoint(args,
                                                   model,
                                                   optimizer,
                                                   path=args.resume)
    if loaded_checkpoint:
        args, best_epoch_error, avg_epoch_error, model, optimizer = loaded_checkpoint

    pred_rpy, gt_rpy, test_loss = test(test_loader, model, criterion, args)

    print("Test Epoch Loss {}".format(test_loss))

    # save test results
    if not isdir(args.save_test_res):
        os.mkdir(args.save_test_res)

    # FIX: pickle writes bytes — the file must be opened in binary mode
    # ('wb'); mode 'w' raises TypeError on Python 3.
    with open(os.path.join(args.save_test_res, 'raw_test_results.pkl'),
              'wb') as f:
        pickle.dump([pred_rpy, gt_rpy, test_loss], f)
# Two parameter groups: a lower LR for the (presumably pretrained) feature
# extractor, a higher LR for the classifier head.
OPTIM = Adam(params=[
    {
        "params": MODEL.features.parameters(),
        'lr': 0.0001
    },
    {
        "params": MODEL.classifier.parameters(),
        'lr': 0.001
    },
], )

# Focal loss; gamma controls how strongly easy examples are down-weighted.
CRITERION = FocalLoss(gamma=2.0)

# Step-wise LR decay: halve the LR at each listed epoch milestone.
LR_SCHEDULERS = [
    MultiStepLR(OPTIM, milestones=[5, 7, 9, 10, 11, 12, 13], gamma=0.5)
]

# Additionally halve the LR when the monitored ('min') metric plateaus.
REDUCE_LR_ON_PLATEAU = ReduceLROnPlateau(OPTIM,
                                         mode='min',
                                         factor=0.5,
                                         patience=5,
                                         threshold=0.05,
                                         verbose=True)

EARLY_STOPPING_KWARGS = {
    'patience': 30,  # epochs without improvement before stopping
    # 'score_function': None
}

LOG_INTERVAL = 100  # iterations between log outputs
def main(train_args):
    """Train FCN-8s on Pascal VOC with SGD and plateau LR scheduling.

    Resumes from ``train_args['snapshot']`` when set (epoch and best-record
    stats are parsed back out of the snapshot filename), otherwise starts
    from scratch.

    Args:
        train_args: dict of hyperparameters ('lr', 'momentum',
            'weight_decay', 'epoch_num', 'lr_patience', 'snapshot', ...).
    """
    net = FCN8s(num_classes=voc.num_classes, caffe=True).cuda()

    if len(train_args['snapshot']) == 0:
        curr_epoch = 1
        train_args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
    else:
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        # snapshot filename encodes epoch and metrics at fixed '_'-separated positions
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {'epoch': int(split_snapshot[1]), 'val_loss': float(split_snapshot[3]),
                                     'acc': float(split_snapshot[5]), 'acc_cls': float(split_snapshot[7]),
                                     'mean_iu': float(split_snapshot[9]), 'fwavacc': float(split_snapshot[11])}

    net.train()

    # Caffe-style preprocessing: BGR order, mean-subtracted, 0-255 range
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.Lambda(lambda x: x.div_(255)),
        standard_transforms.ToPILImage(),
        extended_transforms.FlipChannels()
    ])
    visualize = standard_transforms.Compose([
        standard_transforms.Scale(400),
        standard_transforms.CenterCrop(400),
        standard_transforms.ToTensor()
    ])

    train_set = voc.VOC('train', transform=input_transform, target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=1, num_workers=4, shuffle=True)
    val_set = voc.VOC('val', transform=input_transform, target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=1, num_workers=4, shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=False, ignore_index=voc.ignore_label).cuda()

    # biases get double LR and no weight decay (classic FCN recipe)
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * train_args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': train_args['lr'], 'weight_decay': train_args['weight_decay']}
    ], momentum=train_args['momentum'])

    if len(train_args['snapshot']) > 0:
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + train_args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # FIX: close the log file deterministically instead of leaking the
    # handle returned by open(...).write(...).
    with open(os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(train_args) + '\n\n')

    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=train_args['lr_patience'], min_lr=1e-10, verbose=True)
    for epoch in range(curr_epoch, train_args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, train_args)
        val_loss = validate(val_loader, net, criterion, optimizer, epoch, train_args, restore_transform, visualize)
        scheduler.step(val_loss)
Esempio n. 38
0
def training_main(model_name: str, train_config: Dict[str, Any],
                  model_config: Dict[str, int]):
    """Train an object-tracking model and checkpoint the best dev-IoU weights.

    Builds train/dev datasets from paths in ``train_config``, trains with an
    L1 prediction loss (optionally masked to visible objects) plus a
    frame-to-frame consistency penalty, evaluates mean IoU after every epoch,
    and saves a checkpoint whenever the dev mean IoU improves.

    Args:
        model_name: key used by ModelsFactory/DatasetsFactory; also selects
            loss behavior via DOUBLE_OUTPUT_MODELS / NO_LABELS_MODELS.
        train_config: training hyperparameters and data/checkpoint paths.
        model_config: architecture hyperparameters for the model factory.
    """

    # create train and dev datasets using the files specified in the training configuration
    train_samples_dir = train_config["train_sample_dir"]
    train_labels_dir = train_config["train_labels_dir"]
    train_containment_file = train_config["train_containment_file"]

    dev_samples_dir = train_config["dev_sample_dir"]
    dev_labels_dir = train_config["dev_labels_dir"]
    dev_containment_file = train_config["dev_containment_file"]

    train_dataset: data.Dataset = DatasetsFactory.get_training_dataset(
        model_name, train_samples_dir, train_labels_dir,
        train_containment_file)
    dev_dataset: data.Dataset = DatasetsFactory.get_training_dataset(
        model_name, dev_samples_dir, dev_labels_dir, dev_containment_file)

    # training hyper parameters and configuration
    batch_size = train_config["batch_size"]
    num_workers = train_config["num_workers"]
    num_epochs = train_config["num_epochs"]
    learning_rate = train_config["learning_rate"]
    print_batch_step = train_config["print_step"]
    inference_batch_size = train_config["inference_batch_size"]
    scheduler_patience = train_config["lr_scheduler_patience"]
    scheduler_factor = train_config["lr_scheduler_factor"]
    checkpoints_path = train_config["checkpoints_path"]
    device = torch.device(train_config["device"])
    # consistency_rate = train_config["consistency_rate"]

    # model, loss and optimizer
    model: nn.Module = ModelsFactory.get_model(model_name, model_config)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='min',
                                  factor=scheduler_factor,
                                  patience=scheduler_patience,
                                  verbose=True)
    # reduction="none" keeps per-element losses so the visibility mask can
    # be applied before averaging
    loss_function = nn.L1Loss(reduction="none")

    # create data loaders
    train_config_dict = {"batch_size": batch_size, "num_workers": num_workers}
    inference_config_dict = {
        "batch_size": inference_batch_size,
        "num_workers": num_workers
    }
    training_loader = data.DataLoader(train_dataset, **train_config_dict)
    # separate loader over the train set for end-of-epoch IoU evaluation
    train_inference_loader = data.DataLoader(train_dataset,
                                             **inference_config_dict)
    dev_loader = data.DataLoader(dev_dataset, **inference_config_dict)

    # Start training
    model = model.to(device)
    highest_dev_iou: float = 0
    train_start_time = time.time()

    for epoch in range(num_epochs):
        model.train(mode=True)
        epoch_num = epoch + 1

        # loss statistics (running sums over the last print_batch_step batches)
        batches_running_loss = 0
        batches_running_pred_loss = 0
        batches_running_const_loss = 0

        for batch_idx, sample in enumerate(training_loader, 1):

            x, y, _ = sample
            boxes, _ = x
            labels, mask = y
            boxes = boxes.to(device)
            labels, mask = labels.to(device), mask.to(device)

            optimizer.zero_grad()

            # some models also return an index-tracking head; only the box
            # output is used for the loss here
            if model_name in DOUBLE_OUTPUT_MODELS:
                output, index_to_track_prediction = model(boxes)

            else:
                output = model(boxes)

            # prediction loss
            pred_loss = loss_function(output, labels)

            # consistency loss: L2 distance between predictions of
            # consecutive frames, encouraging smooth trajectories
            next_output_frames = output[:, 1:, :]
            current_output_frames = output[:, :-1, :]
            consistency_loss = torch.mean(
                torch.norm(next_output_frames - current_output_frames,
                           p=2,
                           dim=-1))

            if model_name in NO_LABELS_MODELS:
                pred_loss = pred_loss * mask  # mask contains only visible objects
                pred_loss = torch.mean(pred_loss)

            else:
                pred_loss = torch.mean(pred_loss)

            # only the no-labels models get the consistency term
            if model_name in NO_LABELS_MODELS:
                loss = pred_loss + 0.5 * consistency_loss

            else:
                loss = pred_loss

            batches_running_loss += loss.item()
            batches_running_pred_loss += pred_loss.item()
            batches_running_const_loss += consistency_loss.item()

            loss.backward()
            optimizer.step()

            # print inter epoch statistics
            if batch_idx % print_batch_step == 0:

                num_samples_seen = batch_idx * batch_size
                num_samples_total = len(train_dataset)
                epoch_complete_ratio = 100 * batch_idx / len(training_loader)
                average_running_loss = batches_running_loss / print_batch_step
                average_pred_loss = batches_running_pred_loss / print_batch_step
                average_consist_loss = batches_running_const_loss / print_batch_step
                time_since_beginning = int(time.time() - train_start_time)

                print(
                    "Train Epoch: {} [{}/{} ({:.0f}%)]\t Average Loss: Total {:.4f}, Pred {:.4f} Consistent {:.4f} Training began {} seconds ago"
                    .format(epoch_num, num_samples_seen, num_samples_total,
                            epoch_complete_ratio, average_running_loss,
                            average_pred_loss, average_consist_loss,
                            time_since_beginning))

                batches_running_loss = 0
                batches_running_pred_loss = 0
                batches_running_const_loss = 0

        # end of epoch - compute mean iou over train and dev
        train_loss, train_miou, train_containment_miou = inference_and_iou_comp(
            model_name, model, device, train_inference_loader,
            len(train_dataset), loss_function)
        dev_loss, dev_miou, dev_containment_miou = inference_and_iou_comp(
            model_name, model, device, dev_loader, len(dev_dataset),
            loss_function)

        print(
            "Epoch {} Training Set: Loss {:.4f}, Mean IoU {:.6f}, Mask Mean Iou {:.6f}"
            .format(epoch_num, train_loss, train_miou, train_containment_miou))
        print(
            "Epoch {} Dev Set: Loss {:.4f}, Mean IoU {:.6f}, Mask Mean Iou {:.6f}"
            .format(epoch_num, dev_loss, dev_miou, dev_containment_miou))

        # learning rate scheduling
        # NOTE(review): the scheduler monitors *train* loss, not dev loss —
        # confirm this is intentional (dev_loss is computed just above).
        scheduler.step(train_loss)

        # check if it is the best performing model so far and save it
        if dev_miou > highest_dev_iou:
            highest_dev_iou = dev_miou
            save_checkpoint(model, model_name, round(highest_dev_iou, 3),
                            checkpoints_path)
Esempio n. 39
0
def train_model(train_data,
                validate_data,
                test_data,
                lr=0.1,
                learning_model='random_walk_distribution',
                block_selection='coverage',
                n_epochs=150,
                batch_size=100,
                optimizer='adam',
                omega=4,
                training_method='two-stage',
                max_norm=0.1,
                block_cut_size=0.5):

    net2 = GCNPredictionNet2(feature_size)
    net2.train()
    if optimizer == 'adam':
        optimizer = optim.Adam(net2.parameters(), lr=lr)
    elif optimizer == 'sgd':
        optimizer = optim.SGD(net2.parameters(), lr=lr)
    elif optimizer == 'adamax':
        optimizer = optim.Adamax(net2.parameters(), lr=lr)

    # scheduler = ReduceLROnPlateau(optimizer, 'min')
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5)

    training_loss_list, validating_loss_list, testing_loss_list = [], [], []
    training_defender_utility_list, validating_defender_utility_list, testing_defender_utility_list = [], [], []

    print("Training...")
    forward_time, inference_time, qp_time, backward_time = 0, 0, 0, 0

    evaluate = False if training_method == 'two-stage' else True  # for two-stage only
    pretrain_epochs = 0
    decay_rate = 0.95
    for epoch in range(-1, n_epochs):
        if epoch == n_epochs - 1:
            evaluate = True

        epoch_forward_time, epoch_qp_time, epoch_backward_time, epoch_inference_time = 0, 0, 0, 0
        if epoch <= pretrain_epochs:
            ts_weight = 1
            df_weight = 0
        else:
            ts_weight = decay_rate**(epoch - pretrain_epochs)
            df_weight = 1 - ts_weight

        for mode in ["training", "validating", "testing"]:
            if mode == "training":
                dataset = train_data
                epoch_loss_list = training_loss_list
                epoch_def_list = training_defender_utility_list
                if epoch > 0:
                    net2.train()
                else:
                    net2.eval()
            elif mode == "validating":
                dataset = validate_data
                epoch_loss_list = validating_loss_list
                epoch_def_list = validating_defender_utility_list
                net2.eval()
            elif mode == "testing":
                dataset = test_data
                epoch_loss_list = testing_loss_list
                epoch_def_list = testing_defender_utility_list
                net2.eval()
            else:
                raise TypeError("Not valid mode: {}".format(mode))

            loss_list, def_obj_list = [], []
            for iter_n in tqdm.trange(len(dataset)):
                forward_start_time = time.time()
                G, Fv, coverage_prob, phi_true, path_list, cut, log_prob, unbiased_probs_true, previous_gradient = dataset[
                    iter_n]
                n, m = G.number_of_nodes(), G.number_of_edges()

                # =============== Compute edge probabilities ===========
                Fv_torch = torch.as_tensor(Fv, dtype=torch.float)
                edge_index = torch.Tensor(list(
                    nx.DiGraph(G).edges())).long().t()
                phi_pred = net2(Fv_torch, edge_index).view(
                    -1
                ) if epoch >= 0 else phi_true  # when epoch < 0, testing the optimal loss and defender utility

                unbiased_probs_pred = phi2prob(
                    G, phi_pred) if epoch >= 0 else unbiased_probs_true
                biased_probs_pred = prob2unbiased(
                    G, -coverage_prob, unbiased_probs_pred,
                    omega=omega)  # feeding negative coverage to be biased

                # =================== Compute loss =====================
                log_prob_pred = torch.zeros(1)
                for path in path_list:
                    for e in path:
                        log_prob_pred -= torch.log(
                            biased_probs_pred[e[0]][e[1]])
                log_prob_pred /= len(path_list)
                loss = (log_prob_pred - log_prob)[0]

                single_forward_time = time.time() - forward_start_time
                single_qp_time = 0

                # ============== COMPUTE DEFENDER UTILITY ==============
                single_data = dataset[iter_n]

                if mode == 'testing' or mode == "validating" or epoch <= 0:  # or training_method == "two-stage" or epoch <= 0:
                    cut_size = m
                    if evaluate or epoch <= 0:
                        def_obj, def_coverage, (
                            single_inference_time,
                            single_qp_time) = getDefUtility(
                                single_data,
                                unbiased_probs_pred,
                                learning_model,
                                cut_size=cut_size,
                                omega=omega,
                                verbose=False,
                                training_mode=False,
                                training_method=training_method,
                                block_selection=block_selection
                            )  # feed forward only
                    else:
                        def_obj, def_coverage = torch.Tensor([-float('Inf')
                                                              ]), None
                else:
                    if training_method == "two-stage" or epoch <= pretrain_epochs:
                        cut_size = m
                        if evaluate:
                            def_obj, def_coverage, (
                                single_inference_time,
                                single_qp_time) = getDefUtility(
                                    single_data,
                                    unbiased_probs_pred,
                                    learning_model,
                                    cut_size=cut_size,
                                    omega=omega,
                                    verbose=False,
                                    training_mode=False,
                                    training_method=training_method,
                                    block_selection=block_selection
                                )  # most time-consuming part
                        else:
                            def_obj, def_coverage = torch.Tensor(
                                [-float('Inf')]), None
                            single_inference_time, single_qp_time = 0, 0
                            # ignore the time of computing defender utility
                    else:
                        if training_method == 'decision-focused':
                            cut_size = m
                        elif training_method == 'block-decision-focused' or training_method == 'hybrid' or training_method == 'corrected-block-decision-focused':
                            if type(block_cut_size
                                    ) == str and block_cut_size[-1] == 'n':
                                cut_size = int(n * float(block_cut_size[:-1]))
                            elif block_cut_size <= 1:
                                cut_size = int(m * block_cut_size)
                            else:
                                cut_size = block_cut_size
                        else:
                            raise TypeError('Not defined method')

                        def_obj, def_coverage, (
                            single_inference_time,
                            single_qp_time) = getDefUtility(
                                single_data,
                                unbiased_probs_pred,
                                learning_model,
                                cut_size=cut_size,
                                omega=omega,
                                verbose=False,
                                training_mode=True,
                                training_method=training_method,
                                block_selection=block_selection
                            )  # most time-consuming part

                if epoch > 0 and mode == "training":
                    epoch_forward_time += single_forward_time
                    epoch_inference_time += single_inference_time
                    epoch_qp_time += single_qp_time

                def_obj_list.append(def_obj.item())
                loss_list.append(loss.item())

                if (iter_n % batch_size == (batch_size - 1)) and (
                        epoch > 0) and (mode == "training"):
                    backward_start_time = time.time()
                    optimizer.zero_grad()
                    try:
                        if training_method == "two-stage" or epoch <= pretrain_epochs:
                            loss.backward()
                        elif training_method == "decision-focused" or training_method == "block-decision-focused" or training_method == 'corrected-block-decision-focused':
                            (-def_obj * m / cut_size).backward()
                        elif training_method == "hybrid":
                            ((-def_obj * m / cut_size) * df_weight +
                             loss * ts_weight).backward()
                        else:
                            raise TypeError("Not Implemented Method")
                        torch.nn.utils.clip_grad_norm_(
                            net2.parameters(),
                            max_norm=max_norm)  # gradient clipping
                        optimizer.step()
                    except:
                        print("no grad is backpropagated...")
                    epoch_backward_time += time.time() - backward_start_time

            if (epoch > 0) and (mode == "validating"):
                if training_method == "two-stage":
                    scheduler.step(np.mean(loss_list))
                elif training_method == "decision-focused" or training_method == "block-decision-focused" or training_method == 'corrected-block-decision-focused' or training_method == 'hybrid':
                    scheduler.step(-np.mean(def_obj_list))
                else:
                    raise TypeError("Not Implemented Method")

            # Storing loss and defender utility
            epoch_loss_list.append(np.mean(loss_list))
            epoch_def_list.append(np.mean(def_obj_list))

            ################################### Print stuff after every epoch
            np.random.shuffle(dataset)
            print("Mode: {}/ Epoch number: {}/ Loss: {}/ DefU: {}".format(
                mode, epoch, np.mean(loss_list), np.mean(def_obj_list)))

        print('Forward time for this epoch: {}'.format(epoch_forward_time))
        print('QP time for this epoch: {}'.format(epoch_qp_time))
        print('Backward time for this epoch: {}'.format(epoch_backward_time))
        if epoch >= 0:
            forward_time += epoch_forward_time
            inference_time += epoch_inference_time
            qp_time += epoch_qp_time
            backward_time += epoch_backward_time

        # ============= early stopping criteria =============
        kk = 3
        if epoch >= kk * 2 - 1:
            if training_method == 'two-stage':
                if evaluate:
                    break
                GE_counts = np.sum(
                    np.array(validating_loss_list[1:][-kk:]) >=
                    np.array(validating_loss_list[1:][-2 * kk:-kk]) - 1e-4)
                print('Generalization error increases counts: {}'.format(
                    GE_counts))
                if GE_counts == kk:
                    evaluate = True
            else:  # surrogate or decision-focused
                GE_counts = np.sum(
                    np.array(validating_defender_utility_list[1:][-kk:]) <= np.
                    array(validating_defender_utility_list[1:][-2 * kk:-kk]) +
                    1e-4)
                print('Generalization error increases counts: {}'.format(
                    GE_counts))
                if GE_counts == kk:
                    break

    average_nodes = np.mean([x[0].number_of_nodes() for x in train_data] +
                            [x[0].number_of_nodes() for x in validate_data] +
                            [x[0].number_of_nodes() for x in test_data])
    average_edges = np.mean([x[0].number_of_edges() for x in train_data] +
                            [x[0].number_of_edges() for x in validate_data] +
                            [x[0].number_of_edges() for x in test_data])
    print('Total forward time: {}'.format(forward_time))
    print('Total inference time: {}'.format(inference_time))
    print('Total qp time: {}'.format(qp_time))
    print('Total backward time: {}'.format(backward_time))

    return net2, training_loss_list, validating_loss_list, testing_loss_list, training_defender_utility_list, validating_defender_utility_list, testing_defender_utility_list, (
        forward_time, inference_time, qp_time, backward_time), epoch
# Esempio n. 40
# 0
def main(args):
    """Train a MultiSAVQA model on a VQA dataset.

    Sets up logging and the checkpoint directory, builds the train/val data
    loaders and the model, then runs the training loop with periodic
    evaluation, ReduceLROnPlateau learning-rate scheduling on the validation
    loss, and checkpointing after every epoch (plus mid-epoch saves).

    Args:
        args: Parsed command-line namespace. Fields read here include:
            seed, model_dir, vocab_path, top_answers, dataset, val_dataset,
            batch_size, num_workers, max_examples, max_length, hidden_size,
            vocab_embed_size, num_layers, rnn_cell, bidirectional, dropout,
            num_att_layers, att_ff_size, learning_rate, patience,
            num_epochs, eval_every_n_steps, log_step, save_step.
    """
    # Seed RNGs for reproducibility. torch.cuda.manual_seed is a silent
    # no-op when CUDA is unavailable, so this is safe on CPU-only hosts.
    torch.cuda.manual_seed(args.seed)
    torch.manual_seed(args.seed)

    # Create the model/checkpoint directory. exist_ok=True avoids the
    # check-then-create race of an os.path.exists() guard.
    os.makedirs(args.model_dir, exist_ok=True)

    # Log both to a file inside the model directory and to the console.
    log_format = '%(levelname)-8s %(message)s'
    logfile = os.path.join(args.model_dir, 'train.log')
    logging.basicConfig(filename=logfile,
                        level=logging.INFO,
                        format=log_format)
    logging.getLogger().addHandler(logging.StreamHandler())
    logging.info(json.dumps(args.__dict__))

    # Persist the run arguments alongside the checkpoints.
    with open(os.path.join(args.model_dir, 'args.json'), 'w') as args_file:
        json.dump(args.__dict__, args_file)

    # Load the vocabulary wrapper and attach the top-answer list, which
    # serves as the classification label set.
    vocab = load_vocab(args.vocab_path)
    # FIX: the original json.load(open(...)) leaked the file handle; a
    # context manager closes it deterministically.
    with open(args.top_answers) as top_answers_file:
        vocab.top_answers = json.load(top_answers_file)

    # Build data loaders for training and validation.
    logging.info("Building data loader...")
    data_loader = get_vqa_loader(args.dataset,
                                 args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 max_examples=args.max_examples)

    val_data_loader = get_vqa_loader(args.val_dataset,
                                     args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)
    logging.info("Done")

    # Build the model.
    logging.info("Building MultiSAVQA models...")
    vqa = MultiSAVQAModel(len(vocab),
                          args.max_length,
                          args.hidden_size,
                          args.vocab_embed_size,
                          num_layers=args.num_layers,
                          rnn_cell=args.rnn_cell,
                          bidirectional=args.bidirectional,
                          input_dropout_p=args.dropout,
                          dropout_p=args.dropout,
                          num_att_layers=args.num_att_layers,
                          att_ff_size=args.att_ff_size)
    logging.info("Done")

    if torch.cuda.is_available():
        vqa.cuda()

    # Loss and optimizer.
    criterion = nn.CrossEntropyLoss()
    if torch.cuda.is_available():
        criterion.cuda()

    # Only the parameters the model exposes for training are optimized
    # (vqa.params_to_train() presumably excludes frozen layers — defined
    # on the model class, not visible here).
    params = vqa.params_to_train()
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)
    # Halt-and-reduce LR when the validation loss plateaus.
    scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                  mode='min',
                                  factor=0.1,
                                  patience=args.patience,
                                  verbose=True,
                                  min_lr=1e-6)

    # Training loop.
    total_steps = len(data_loader) * args.num_epochs
    start_time = time.time()
    n_steps = 0
    for epoch in range(args.num_epochs):
        for i, (feats, questions, categories) in enumerate(data_loader):
            n_steps += 1

            # Move the mini-batch to the GPU when available.
            if torch.cuda.is_available():
                feats = feats.cuda()
                questions = questions.cuda()
                categories = categories.cuda()
            qlengths = process_lengths(questions)

            # Forward pass.
            vqa.train()
            vqa.zero_grad()
            outputs = vqa(feats, questions, qlengths)

            # Classification loss over the top-answer label set.
            loss = criterion(outputs, categories)

            # Backprop and optimize.
            loss.backward()
            optimizer.step()

            # Mid-epoch evaluation every eval_every_n_steps global steps.
            # (The `n_steps >=` clause is redundant given the modulo test
            # but is kept for clarity of intent.)
            if (args.eval_every_n_steps is not None
                    and n_steps >= args.eval_every_n_steps
                    and n_steps % args.eval_every_n_steps == 0):
                logging.info('=' * 100)
                val_loss = evaluate(vocab, vqa, val_data_loader, criterion,
                                    epoch, args)
                scheduler.step(val_loss)
                logging.info('=' * 100)

            # Predicted class = argmax over the answer logits.
            preds = outputs.max(1)[1]
            score = accuracy(preds, categories)

            # Periodic progress logging.
            if i % args.log_step == 0:
                delta_time = time.time() - start_time
                start_time = time.time()
                logging.info('Time: %.4f, Epoch [%d/%d], Step [%d/%d], '
                             'Accuracy: %.4f, Loss: %.4f, LR: %f' %
                             (delta_time, epoch + 1,
                              args.num_epochs, n_steps, total_steps, score,
                              loss.item(), optimizer.param_groups[0]['lr']))

            # Mid-epoch checkpoint.
            if (i + 1) % args.save_step == 0:
                torch.save(
                    vqa.state_dict(),
                    os.path.join(args.model_dir,
                                 'multi-savqa-%d-%d.pkl' % (epoch + 1, i + 1)))

        # End-of-epoch checkpoint.
        torch.save(
            vqa.state_dict(),
            os.path.join(args.model_dir, 'multi-savqa-%d.pkl' % (epoch + 1)))

        # End-of-epoch evaluation and learning-rate update.
        logging.info('=' * 100)
        val_loss = evaluate(vocab, vqa, val_data_loader, criterion, epoch,
                            args)
        scheduler.step(val_loss)
        logging.info('=' * 100)

    # Save the final model.
    torch.save(vqa.state_dict(), os.path.join(args.model_dir, 'vqa.pkl'))