Example #1
    def __init__(self, src_vocab, tgt_vocab,
                 max_len=300, hidden_size=300, n_layers=2, clip=5, n_epochs=30):
        # hyper-parameters
        self.max_len = max_len
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.clip = clip
        self.n_epochs = n_epochs

        # vocab
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.pad_idx = self.src_vocab.stoi[PAD]

        # prepare model
        self.encoder = Encoder(self.src_vocab, self.max_len, self.hidden_size, self.n_layers)
        self.decoder = Decoder(self.tgt_vocab, self.max_len, self.hidden_size * 2, self.n_layers)
        self.reverse_decoder = Decoder(self.tgt_vocab, self.max_len, self.hidden_size * 2, self.n_layers, reverse=True)
        self.model = Seq2SeqConcat(self.encoder, self.decoder, self.reverse_decoder, self.pad_idx)
        self.model.to(device)
        print(self.model)
        print("Total parameters:", sum([p.nelement() for p in self.model.parameters()]))

        # initialize weights
        for name, param in self.model.named_parameters():
            if "lstm.bias" in name:
                # set lstm forget gate to 1 (Jozefowicz et al., 2015)
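                # (PyTorch stores LSTM biases as one vector [b_i | b_f | b_g | b_o],
                #  so the slice n//4:n//2 below addresses the forget-gate portion)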
                n = param.size(0)
                param.data[n//4:n//2].fill_(1.0)
            elif "lstm.weight" in name:
                nn.init.xavier_uniform_(param)

        # prepare loss function; don't calculate loss on PAD tokens
        self.criterion = nn.NLLLoss(ignore_index=self.pad_idx)

        # prepare optimizer and scheduler
        self.optimizer = Adam(self.model.parameters())
        self.scheduler = CyclicLR(self.optimizer, base_lr=0.00001, max_lr=0.00005,
                                  step_size_up=4000, step_size_down=4000,
                                  mode="triangular", gamma=1.0, cycle_momentum=False)

        # bookkeeping vars
        self.global_iter = 0
        self.global_numel = []
        self.global_loss = []
        self.global_acc = []

        # visualization
        self.vis_loss = Visualization(env_name="aivivn_tone", xlabel="step", ylabel="loss", title="loss (mean per 300 steps)")
        self.vis_acc = Visualization(env_name="aivivn_tone", xlabel="step", ylabel="acc", title="training accuracy (mean per 300 steps)")
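
For context (not part of the snippet above): torch.optim.lr_scheduler.CyclicLR, as configured here, is meant to be stepped once per batch rather than once per epoch. A minimal, self-contained sketch with a toy model; all names below are stand-ins, not from this codebase:

# Sketch only: per-batch stepping of a torch.optim CyclicLR scheduler.
import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CyclicLR

toy_model = nn.Linear(4, 2)
toy_optimizer = Adam(toy_model.parameters())
toy_scheduler = CyclicLR(toy_optimizer, base_lr=1e-5, max_lr=5e-5,
                         step_size_up=4000, step_size_down=4000,
                         mode="triangular", cycle_momentum=False)

for _ in range(10):                       # stands in for iterating over training batches
    toy_optimizer.zero_grad()
    loss = toy_model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    toy_optimizer.step()
    toy_scheduler.step()                  # one scheduler step per batch, not per epoch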
Example #2
def main(_):
    tf.gfile.MakeDirs(FLAGS.output_dir)

    if FLAGS.is_fixed_emb:
        emb_matrix = utils.get_emb_matrix(FLAGS.data_dir, FLAGS.max_features)

    clr = CyclicLR(base_lr=FLAGS.min_lr,
                   max_lr=FLAGS.max_lr,
                   step_size=2740,
                   mode='exp_range',
                   gamma=0.99994)
    matcher = TextMatcher(FLAGS.model_name, FLAGS.vocab_file,
                          FLAGS.do_lower_case, FLAGS.max_seq_len)

    model_builder = ModelBuilder(model_name=FLAGS.model_name,
                                 max_len=FLAGS.max_seq_len,
                                 input_dim=FLAGS.input_dim,
                                 max_features=FLAGS.max_features,
                                 units=FLAGS.units,
                                 num_filter=FLAGS.num_filter)

    if FLAGS.is_fixed_emb:
        model_builder.set_embedding_matrix(emb_matrix)

    model = model_builder.build_model()

    print(model.summary())

    if FLAGS.do_train:
        train_example = matcher.get_train_examples(FLAGS.data_dir)
        matcher.do_train(model,
                         FLAGS.output_dir,
                         train_example,
                         FLAGS.epochs,
                         FLAGS.batch_size,
                         callback=[
                             clr,
                         ])

    if FLAGS.do_eval:
        dev_example = matcher.get_dev_examples(FLAGS.data_dir)
        matcher.do_eval(model, FLAGS.output_dir, dev_example, FLAGS.batch_size)

    if FLAGS.do_predict:
        test_example = matcher.get_test_examples(FLAGS.data_dir)
        matcher.do_predict(model, FLAGS.output_dir, test_example,
                           FLAGS.batch_size)
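
For reference, a rough sketch of what mode='exp_range' with gamma=0.99994 means, assuming the widely used Keras CLR formulation (a triangular wave whose amplitude decays by gamma each iteration). The function below is not taken from this repo, and the values in the call are illustrative only:

import numpy as np

# Sketch only, assuming the common Keras CLR 'exp_range' formulation.
def clr_exp_range(iteration, base_lr, max_lr, step_size, gamma):
    cycle = np.floor(1 + iteration / (2 * step_size))
    x = np.abs(iteration / step_size - 2 * cycle + 1)
    # triangular wave whose amplitude decays by gamma every iteration
    return base_lr + (max_lr - base_lr) * max(0.0, 1 - x) * gamma ** iteration

print(clr_exp_range(2740, base_lr=1e-5, max_lr=1e-3, step_size=2740, gamma=0.99994))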
Example #3
def main(model, batch_size, n_epochs, lr, train_fpath, val_fpath, train_preprocess, val_preprocess, multimask, patience, _run):
    run_validation = val_fpath is not None

    assert train_preprocess in available_conditioning, "Train pre-process '{}' is not available. Available functions are: '{}'".format(train_preprocess, list(available_conditioning.keys()))
    if run_validation:
        assert val_preprocess in available_conditioning, "Validation pre-process '{}' is not available. Available functions are: '{}'".format(val_preprocess, list(available_conditioning.keys()))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    writer = SummaryWriter(os.path.join(base_path, "runs", "experiment-{}".format(_run._id)))
    best_model_path = os.path.join(fs_observer.dir, "best_model.pth")
    last_model_path = os.path.join(fs_observer.dir, "last_model.pth")

    outputs_path = os.path.join(fs_observer.dir, "outputs")
    if not os.path.exists(outputs_path):
        os.mkdir(outputs_path)

    if model == "deeplab":
        model = DeepLab(num_classes=1).to(device)
    elif model == "autodeeplab":
        model = AutoDeeplab(num_classes=1).to(device)
    elif model == "unet":
        model = UNet11(pretrained=True).to(device)
    elif model == "linknet":
        model = LinkNet(n_classes=1).to(device)
    elif model == "refinenet":
        model = RefineNet4Cascade(input_shape=(3, 256)).to(device)  # 3 channels, 256x256 input
    else:
        raise Exception("Invalid model '{}'".format(model))

    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = CyclicLR(optimizer, base_lr=1e-5, max_lr=1e-4, step_size=500)
    loss_fn = SoftJaccardBCEWithLogitsLoss(jaccard_weight=8)

    augmentations = [
        GaussianNoise(0, 2),
        EnhanceContrast(0.5, 0.1),
        EnhanceColor(0.5, 0.1)
    ]

    dataloaders = {}
    train_preprocess_fn = available_conditioning[train_preprocess]
    val_preprocess_fn = available_conditioning[val_preprocess]

    train_dataset_args = dict(
        fpath=train_fpath,
        augmentations=augmentations,
        target_preprocess=train_preprocess_fn
    )
    validation_dataset_args = dict(
        fpath=val_fpath,
        target_preprocess=val_preprocess_fn
    )

    if multimask:
        DatasetClass = MultimaskSkinLesionSegmentationDataset
        train_dataset_args["select"] = "random"
        validation_dataset_args["select"] = "all"
    else:
        DatasetClass = SkinLesionSegmentationDataset

    train_dataset = DatasetClass(**train_dataset_args)
    dataloaders["train"] = data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           num_workers=8,
                                           shuffle=True,
                                           worker_init_fn=set_seeds)

    if run_validation:
        val_dataset = DatasetClass(**validation_dataset_args)
        dataloaders["validation"] = data.DataLoader(val_dataset,
                                                    batch_size=batch_size if not multimask else 1,
                                                    num_workers=8,
                                                    shuffle=False,
                                                    worker_init_fn=set_seeds)

    info = {}
    epochs = range(1, n_epochs + 1)
    best_jacc = 0
    epochs_since_best = 0

    for epoch in epochs:
        info["train"] = run_epoch("train", epoch, model, dataloaders["train"], optimizer, loss_fn, scheduler, writer)

        if run_validation:
            info["validation"] = run_epoch("validation", epoch, model, dataloaders["validation"], optimizer, loss_fn, scheduler, writer)
            if info["validation"]["jaccard_threshold"] > best_jacc:
                best_jacc = info["validation"]["jaccard_threshold"]
                torch.save(model, best_model_path)
                epochs_since_best = 0
            else:
                epochs_since_best += 1

        torch.save(model, last_model_path)
        writer.commit()

        if epochs_since_best > patience:
            break
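
The patience logic above implements early stopping: training halts once the validation Jaccard score has failed to improve for more than `patience` epochs. A tiny self-contained sketch of the same pattern with made-up scores:

# Sketch only: the patience-based early stopping pattern used above, in isolation.
best_metric, epochs_since_best, patience = 0.0, 0, 3
for metric in [0.50, 0.62, 0.61, 0.60, 0.59, 0.58]:   # stand-in validation scores
    if metric > best_metric:
        best_metric, epochs_since_best = metric, 0     # new best: reset the counter
    else:
        epochs_since_best += 1
    if epochs_since_best > patience:
        print("stopping early at best =", best_metric)
        break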
Example #4
class Trainer:
    def __init__(self,
                 src_vocab,
                 tgt_vocab,
                 max_len=300,
                 hidden_size=300,
                 n_layers=2,
                 clip=5,
                 n_epochs=30):
        # hyper-parameters
        self.max_len = max_len
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.clip = clip
        self.n_epochs = n_epochs

        # vocab
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.pad_idx = self.src_vocab.stoi[PAD]

        # prepare model
        self.encoder = Encoder(self.src_vocab, self.max_len, self.hidden_size,
                               self.n_layers)
        self.decoder = Decoder(self.tgt_vocab, self.max_len,
                               self.hidden_size * 2, self.n_layers)
        self.reverse_decoder = Decoder(self.tgt_vocab,
                                       self.max_len,
                                       self.hidden_size * 2,
                                       self.n_layers,
                                       reverse=True)
        self.model = Seq2SeqConcat(self.encoder, self.decoder,
                                   self.reverse_decoder, self.pad_idx)
        self.model.to(device)
        print(self.model)
        print("Total parameters:",
              sum([p.nelement() for p in self.model.parameters()]))

        # initialize weights
        for name, param in self.model.named_parameters():
            if "lstm.bias" in name:
                # set lstm forget gate to 1 (Jozefowicz et al., 2015)
                n = param.size(0)
                param.data[n // 4:n // 2].fill_(1.0)
            elif "lstm.weight" in name:
                nn.init.xavier_uniform_(param)

        # prepare loss function; don't calculate loss on PAD tokens
        self.criterion = nn.NLLLoss(ignore_index=self.pad_idx)

        # prepare optimizer and scheduler
        self.optimizer = Adam(self.model.parameters())
        self.scheduler = CyclicLR(self.optimizer,
                                  base_lr=0.00001,
                                  max_lr=0.00005,
                                  step_size_up=4000,
                                  step_size_down=4000,
                                  mode="triangular",
                                  gamma=1.0,
                                  cycle_momentum=False)

        # bookkeeping vars
        self.global_iter = 0
        self.global_numel = []
        self.global_loss = []
        self.global_acc = []

        # visualization
        # self.vis_loss = Visualization(env_name="aivivn_tone", xlabel="step", ylabel="loss", title="loss (mean per 300 steps)")
        # self.vis_acc = Visualization(env_name="aivivn_tone", xlabel="step", ylabel="acc", title="training accuracy (mean per 300 steps)")

    def train(self,
              train_iterator,
              val_iterator,
              start_epoch=0,
              print_every=100):
        for epoch in range(start_epoch, self.n_epochs):
            self._train_epoch(epoch,
                              train_iterator,
                              train=True,
                              print_every=print_every)
            self.save(epoch)

            # evaluate on validation set after each epoch
            with torch.no_grad():
                self._train_epoch(epoch,
                                  val_iterator,
                                  train=False,
                                  print_every=print_every)

    def train_in_parts(self,
                       train_parts,
                       val,
                       val_iterator,
                       batch_size,
                       start_epoch=0,
                       print_every=100):
        for epoch in range(start_epoch, self.n_epochs):
            # shuffle data each epoch
            random.shuffle(train_parts)

            for train_src_, train_tgt_ in train_parts:
                # create train dataset
                print("Training part [{}] with target [{}]...".format(
                    train_src_, train_tgt_))
                train_ = Seq2SeqDataset.from_file(train_src_,
                                                  train_tgt_,
                                                  share_fields_from=val)

                # create iterator
                train_iterator_ = BucketIterator(dataset=train_,
                                                 batch_size=batch_size,
                                                 sort=False,
                                                 sort_within_batch=True,
                                                 sort_key=lambda x: len(x.src),
                                                 shuffle=True,
                                                 device=device)
                # train
                self._train_epoch(epoch,
                                  train_iterator_,
                                  train=True,
                                  print_every=print_every)

                # clean
                del train_
                del train_iterator_
                gc.collect()

            # save
            self.save(epoch)

            # evaluate on validation set after each epoch
            with torch.no_grad():
                self._train_epoch(epoch,
                                  val_iterator,
                                  train=False,
                                  print_every=print_every)

    def resume(self, train_iterator, val_iterator, save_path):
        checkpoint = torch.load(save_path)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.model.to(device)
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
        self.train(train_iterator, val_iterator, start_epoch)

    def resume_in_parts(self, train_parts, val, val_iterator, batch_size,
                        save_path):
        checkpoint = torch.load(save_path)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.model.to(device)
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
        self.train_in_parts(train_parts,
                            val,
                            val_iterator,
                            batch_size,
                            start_epoch=start_epoch)

    def _train_epoch(self, epoch, batch_iterator, train=True, print_every=100):
        if train:
            self.model.train()
        else:
            self.model.eval()
            print("***Evaluating on validation set***")

        total_loss = 0
        total_correct = 0
        total_numel = 0
        total_iter = 0
        num_batch = len(batch_iterator)

        for i, batch in enumerate(batch_iterator):
            # forward propagation
            # (batch, seq_len, tgt_vocab_size)
            if train:
                # crude annealing teacher forcing
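                # (during epoch 0 the ratio decays linearly from 1.0 toward 0.5 as
                #  batches go by; from epoch 1 onward it stays fixed at 0.5)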
                teacher_forcing = 0.5
                if epoch == 0:
                    teacher_forcing = max(0.5,
                                          (num_batch - total_iter) / num_batch)
                output, reverse_output, combined_output = self.model(
                    batch, mask_softmax=0.5, teacher_forcing=teacher_forcing)
            else:
                output, reverse_output, combined_output = self.model(
                    batch, mask_softmax=1.0, teacher_forcing=1.0)

            # (batch, seq_len)
            target = getattr(batch, tgt_field_name)

            # reshape to calculate loss
            output = output.view(-1, output.size(-1))
            reverse_output = reverse_output.view(-1, reverse_output.size(-1))
            combined_output = combined_output.view(-1,
                                                   combined_output.size(-1))
            target = target.view(-1)

            # calculate loss
            loss = self.criterion(output, target) + self.criterion(
                reverse_output, target) + self.criterion(
                    combined_output, target)

            # backprop
            if train:
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
                self.optimizer.step()
                self.scheduler.step()

            # calculate accuracy
            correct = output.argmax(dim=-1).eq(target).sum().item()
            r_correct = reverse_output.argmax(dim=-1).eq(target).sum().item()
            c_correct = combined_output.argmax(dim=-1).eq(target).sum().item()

            # summarize for each batch
            total_loss += loss.item()
            total_correct += c_correct
            total_numel += target.numel()
            total_iter += 1

            # add to global summary
            if train:
                self.global_iter += 1
                self.global_numel.append(target.numel())
                self.global_loss.append(loss.item())
                self.global_acc.append(c_correct)

                # visualize
                # if self.global_iter == 1:
                #     self.vis_loss.plot_line(self.global_loss[0], 1)
                #     self.vis_acc.plot_line(self.global_acc[0]/total_numel, 1)

                # update graph every 10 iterations
                if self.global_iter % 10 == 0:
                    # moving average of most recent 300 losses
                    moving_avg_loss = sum(
                        self.global_loss[max(0,
                                             len(self.global_loss) -
                                             300):]) / min(
                                                 300.0, self.global_iter)
                    moving_avg_acc = sum(
                        self.global_acc[max(0,
                                            len(self.global_acc) - 300):]
                    ) / sum(
                        self.global_numel[max(0,
                                              len(self.global_numel) - 300):])

                    # visualize
                    # self.vis_loss.plot_line(moving_avg_loss, self.global_iter)
                    # self.vis_acc.plot_line(moving_avg_acc, self.global_iter)

            # print
            if i % print_every == 0:
                template = "epoch = {}  iter = {}  loss = {:5.3f}  correct = {:6.3f}  r_correct = {:6.3f}  c_correct = {:6.3f}"
                print(
                    template.format(epoch, i, loss.item(),
                                    correct / target.numel() * 100.0,
                                    r_correct / target.numel() * 100.0,
                                    c_correct / target.numel() * 100.0))

        # summarize for each epoch
        template = "EPOCH = {}  AVG_LOSS = {:5.3f}  AVG_CORRECT = {:6.3f}\n"
        print(
            template.format(epoch, total_loss / total_iter,
                            total_correct / total_numel * 100.0))

    def save(self, epoch, save_path="checkpoint"):
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "scheduler_state_dict": self.scheduler.state_dict(),
                "global_iter": self.global_iter
            }, os.path.join(save_path, "aivivn_tone.model.ep{}".format(epoch)))
Example #5
# lr_finder = LRFinder(model, stop_factor=4)
# lr_finder.find((train_data, train_label), steps_per_epoch=steps, start_lr=1e-6, lr_mult=1.01, batch_size=num_batchsize)
# lr_finder.plot_loss()

# SGDR learning rate policy
# set learning rate range according to lr range test result
# min_lr = 5e-5
# max_lr = 2e-3
# lr_scheduler = SGDRScheduler(min_lr, max_lr, steps, lr_decay=1.0, cycle_length=1, mult_factor=2)

# one cycle learning rate policy
# set max learning rate according to lr range test result
max_lr = 1e-3
lr_scheduler = CyclicLR(base_lr=max_lr / 10,
                        max_lr=max_lr,
                        step_size=np.ceil(steps * num_epochs / 2),
                        max_momentum=0.95,
                        min_momentum=0.85)


def piecewise_constant_fn(epoch):
    if epoch < 10:
        return 1e-3
    elif epoch < 20:
        return 5e-4
    else:
        return 3e-4


#lr_scheduler = LearningRateScheduler(piecewise_constant_fn)
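
For context (not shown in the excerpt): these Keras callbacks only take effect once passed to model.fit. A minimal sketch, reusing the names the excerpt itself assumes (model, train_data, train_label, num_batchsize, num_epochs):

# Sketch only: attach the chosen learning-rate callback at fit time (Keras API).
history = model.fit(train_data, train_label,
                    batch_size=num_batchsize,
                    epochs=num_epochs,
                    callbacks=[lr_scheduler])  # or [LearningRateScheduler(piecewise_constant_fn)]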
Example #6
            criterion = nn.BCEWithLogitsLoss()

        # define model and optimizer
        if config.model_name == 'se_resnext50':
            model = SE_ResNext50(config.num_classes)
            optimizer = optim.SGD([{
                'params': model.last_linear.parameters(),
                'lr': config.base_lr[1]
            }, {
                'params': model.backbone.parameters()
            }],
                                  lr=config.base_lr[0],
                                  momentum=0.9)
            scheduler = CyclicLR(optimizer=optimizer,
                                 base_lr=config.base_lr,
                                 max_lr=config.max_lr,
                                 step_size=config.step_size,
                                 mode='triangular2')
        elif config.model_name == 'se_resnext50_spatial':
            model = SE_ResNext50_Spatial(config.num_classes)
            optimizer = optim.SGD([{
                'params': model.last_linear.parameters(),
                'lr': config.base_lr[1]
            }, {
                'params': model.backbone.parameters()
            }],
                                  lr=config.base_lr[0],
                                  momentum=0.9)
            scheduler = CyclicLR(optimizer=optimizer,
                                 base_lr=config.base_lr,
                                 max_lr=config.max_lr,
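
Note that config.base_lr is a list: its first entry is the default LR handed to SGD (used by the backbone group) and its second is the LR for the last_linear group, and the whole list is then passed to CyclicLR as base_lr. A minimal sketch of that per-parameter-group pattern with made-up modules and values, written against torch.optim.lr_scheduler.CyclicLR (which also accepts one base_lr/max_lr entry per group), not against this repo's scheduler:

# Sketch only (made-up modules and values): per-parameter-group cyclic learning rates.
from torch import nn, optim
from torch.optim.lr_scheduler import CyclicLR

head = nn.Linear(2048, 28)        # stand-in for model.last_linear
backbone = nn.Linear(2048, 2048)  # stand-in for model.backbone
optimizer = optim.SGD([{'params': head.parameters(), 'lr': 1e-2},
                       {'params': backbone.parameters()}],
                      lr=1e-3, momentum=0.9)
# one base_lr / max_lr entry per parameter group (head first, then backbone)
scheduler = CyclicLR(optimizer, base_lr=[1e-3, 1e-4], max_lr=[1e-2, 1e-3],
                     step_size_up=2000, mode='triangular2')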
Example #7
# calculate total number of threshold parameters

no_thresh_p = 0
for name, p in net.named_parameters():
    check = name.split('.')[-1]
    if check == 'thresh':
        no_thresh_p += p.numel()

lam = 1 / no_thresh_p

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(),
                      lr=args.lr,
                      momentum=0.9,
                      weight_decay=5e-4)
scheduler = CyclicLR(optimizer)

if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/ckpt.t7')
    optim_ck = torch.load('./checkpoint/optim.t7')
    net.load_state_dict(checkpoint['net'])
    optimizer.load_state_dict(optim_ck['optim'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch'] + 1


# Training
def train(epoch, learn_t=False):
Example #8
def main(batch_size, n_epochs, lr, beta1, decay, _run):
    assert torch.cuda.is_available()

    writer = SummaryWriter(
        os.path.join(base_path, "runs", "experiment-{}".format(_run._id)))
    model_path = os.path.join(fs_observer.dir, "best_model.pth")

    outputs_path = os.path.join(fs_observer.dir, "outputs")
    if not os.path.exists(outputs_path):
        os.mkdir(outputs_path)

    cudnn.benchmark = True
    s_model = SegmentorNet().cuda()
    c_model = CriticNet().cuda()

    s_optimizer = optim.Adam(s_model.parameters(), lr=lr, betas=(beta1, 0.999))
    c_optimizer = optim.Adam(c_model.parameters(), lr=lr, betas=(beta1, 0.999))

    s_scheduler = CyclicLR(s_optimizer, base_lr=lr, max_lr=lr * 10)
    c_scheduler = CyclicLR(c_optimizer, base_lr=lr, max_lr=lr * 10)

    dataloaders = {
        "train": loader(Dataset('./'), batch_size),
        "validation": loader(Dataset_val('./'), 36)
    }

    best_IoU = 0.0
    s_model.train()
    for epoch in range(n_epochs):
        progress_bar = tqdm(dataloaders["train"],
                            desc="Epoch {} - train".format(epoch))

        s_losses = []
        s_losses_joint = []
        c_losses = []
        dices = []

        for i, (inputs, targets) in enumerate(progress_bar):
            c_model.zero_grad()

            inputs = Variable(inputs).cuda()
            targets = Variable(targets).cuda().type(torch.FloatTensor).cuda()

            outputs = s_model(inputs)
            outputs = F.sigmoid(outputs)
            outputs = outputs.detach()
            outputs_masked = inputs.clone()
            inputs_mask = inputs.clone()

            for d in range(3):
                outputs_masked[:, d, :, :] = inputs_mask[:, d, :, :].unsqueeze(
                    1) * outputs
            outputs_masked = outputs_masked.cuda()

            results = c_model(outputs_masked)
            targets_masked = inputs.clone()
            for d in range(3):
                targets_masked[:, d, :, :] = inputs_mask[:, d, :, :].unsqueeze(
                    1) * targets

            for d in range(3):
                targets_masked[:, d, :, :] = inputs_mask[:, d, :, :].unsqueeze(
                    1) * targets

            targets_masked = targets_masked.cuda()
            targets_D = c_model(targets_masked)
            loss_D = -torch.mean(torch.abs(results - targets_D))
            loss_D.backward()
            c_optimizer.step()
            c_scheduler.batch_step()

            for p in c_model.parameters():
                p.data.clamp_(-0.05, 0.05)

            s_model.zero_grad()
            outputs = s_model(inputs)
            outputs = F.sigmoid(outputs)

            for d in range(3):
                outputs_masked[:, d, :, :] = inputs_mask[:, d, :, :].unsqueeze(
                    1) * outputs
            outputs_masked = outputs_masked.cuda()

            results = c_model(outputs_masked)
            for d in range(3):
                targets_masked[:, d, :, :] = inputs_mask[:, d, :, :].unsqueeze(
                    1) * targets
            targets_masked = targets_masked.cuda()

            targets_G = c_model(targets_masked)
            loss_dice = dice_loss(outputs, targets)
            loss_G = torch.mean(torch.abs(results - targets_G))
            loss_G_joint = loss_G + loss_dice
            loss_G_joint.backward()
            s_optimizer.step()
            s_scheduler.batch_step()

            c_losses.append(loss_D.data[0])
            s_losses.append(loss_G.data[0])
            s_losses_joint.append(loss_G_joint.data[0])
            dices.append(loss_dice.data[0])

            progress_bar.set_postfix(
                OrderedDict({
                    "c_loss": np.mean(c_losses),
                    "s_loss": np.mean(s_losses),
                    "s_loss_joint": np.mean(s_losses_joint),
                    "dice": np.mean(dices)
                }))

        mean_c_loss = np.mean(c_losses)
        mean_s_loss = np.mean(s_losses)
        mean_s_loss_joint = np.mean(s_losses_joint)
        mean_dice = np.mean(dices)

        c_loss_tag = "train.c_loss"
        s_loss_tag = "train.s_loss"
        s_losses_joint_tag = "train.s_loss_joint"
        dice_loss_tag = "train.loss_dice"

        writer.add_scalar(c_loss_tag, mean_c_loss, epoch)
        writer.add_scalar(s_loss_tag, mean_s_loss, epoch)
        writer.add_scalar(s_losses_joint_tag, mean_s_loss_joint, epoch)
        writer.add_scalar(dice_loss_tag, mean_dice, epoch)

        if epoch % 10 == 0:
            progress_bar = tqdm(dataloaders["validation"],
                                desc="Epoch {} - validation".format(epoch))

            s_model.eval()
            IoUs, dices = [], []
            for i, (inputs, targets) in enumerate(progress_bar):
                inputs = Variable(inputs).cuda()
                targets = Variable(targets).cuda()

                pred = s_model(inputs)
                pred[pred < 0.5] = 0
                pred[pred >= 0.5] = 1

                pred = pred.type(torch.LongTensor)
                pred_np = pred.data.cpu().numpy()

                targets = targets.data.cpu().numpy()
                for x in range(inputs.size()[0]):
                    IoU = np.sum(pred_np[x][targets[x] == 1]) / float(
                        np.sum(pred_np[x]) + np.sum(targets[x]) -
                        np.sum(pred_np[x][targets[x] == 1]))
                    dice = np.sum(pred_np[x][targets[x] == 1]) * 2 / float(
                        np.sum(pred_np[x]) + np.sum(targets[x]))
                    IoUs.append(IoU)
                    dices.append(dice)

                progress_bar.set_postfix(
                    OrderedDict({
                        "mIoU": np.mean(IoUs, axis=0),
                        "mDice": np.mean(dices, axis=0)
                    }))

            s_model.train()
            IoUs = np.array(IoUs, dtype=np.float64)
            dices = np.array(dices, dtype=np.float64)
            mIoU = np.mean(IoUs, axis=0)
            mDice = np.mean(dices, axis=0)

            miou_tag = "validation.miou"
            mdice_tag = "validation.mdice"

            writer.add_scalar(miou_tag, mIoU, epoch)
            writer.add_scalar(mdice_tag, mDice, epoch)
            writer.commit()

            if mIoU > best_IoU:
                best_IoU = mIoU
                torch.save(s_model, model_path)

        if epoch % 25 == 0:
            lr = max(lr * decay, 0.00000001)
            s_optimizer = optim.Adam(s_model.parameters(),
                                     lr=lr,
                                     betas=(beta1, 0.999))
            c_optimizer = optim.Adam(c_model.parameters(),
                                     lr=lr,
                                     betas=(beta1, 0.999))
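
One detail worth flagging in this last excerpt: when s_optimizer and c_optimizer are re-created every 25 epochs, s_scheduler and c_scheduler still wrap the old optimizer objects. If the cyclic schedule is meant to keep applying, a sketch of the extra step (reusing the same CyclicLR construction from earlier in the excerpt; not present in the original) would be:

            # Sketch only (not in the original code): re-wrap the schedulers around
            # the freshly created optimizers so the cyclic schedule stays in effect.
            s_scheduler = CyclicLR(s_optimizer, base_lr=lr, max_lr=lr * 10)
            c_scheduler = CyclicLR(c_optimizer, base_lr=lr, max_lr=lr * 10)

Re-creating the schedulers also restarts each cycle from base_lr, which may or may not be the intended behavior.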