コード例 #1
0
def run_a_train_epoch(args, epoch, model, data_loader, loss_criterion,
                      optimizer):
    """Train ``model`` for one epoch and print the average loss and metrics.

    The model is trained against standardized targets:
    ``(labels - args['train_mean']) / args['train_std']``. The raw labels are
    given to a ``Meter`` built with the same mean/std (presumably so it can
    de-normalize predictions before scoring — confirm in ``Meter``).

    Args:
        args: dict; reads ``'train_mean'``, ``'train_std'``, ``'device'``,
            ``'metrics'`` and ``'num_epochs'``.
        epoch: zero-based epoch index (printed as ``epoch + 1``).
        model: callable taking the batched graph ``bg``.
        data_loader: yields ``(indices, ligand_mols, protein_mols, bg, labels)``.
        loss_criterion: called as ``loss_criterion(prediction, target)``.
        optimizer: optimizer over ``model``'s parameters.
    """
    model.train()
    train_meter = Meter(args['train_mean'], args['train_std'])
    epoch_loss = 0
    for batch_data in data_loader:
        indices, ligand_mols, protein_mols, bg, labels = batch_data
        labels, bg = labels.to(args['device']), bg.to(args['device'])
        prediction = model(bg)
        loss = loss_criterion(prediction, (labels - args['train_mean']) /
                              args['train_std'])
        # Weight by batch size so dividing by the dataset size below yields a
        # per-sample average (assumes loss_criterion returns a batch mean).
        # ``.item()`` replaces the discouraged ``loss.data`` access.
        epoch_loss += loss.item() * len(indices)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Detach so the meter can never keep this batch's autograd graph alive.
        train_meter.update(prediction.detach(), labels)
    avg_loss = epoch_loss / len(data_loader.dataset)
    total_scores = {
        metric: train_meter.compute_metric(metric)
        for metric in args['metrics']
    }
    msg = 'epoch {:d}/{:d}, training | loss {:.4f}'.format(
        epoch + 1, args['num_epochs'], avg_loss)
    msg = update_msg_from_scores(msg, total_scores)
    print(msg)
コード例 #2
0
def run_a_train_epoch(args, epoch, model, data_loader, loss_criterion,
                      optimizer):
    """Run one masked multi-task training epoch, printing per-batch losses.

    Atom features are popped from ``bg.ndata[args['atom_data_field']]``.
    Loss terms whose mask entry is 0 (missing labels) are zeroed before
    averaging. At the end, the mean of ``args['metric_name']`` over the
    meter's output is printed.

    Args:
        args: dict; reads ``'atom_data_field'``, ``'device'``,
            ``'num_epochs'`` and ``'metric_name'``.
        epoch: zero-based epoch index (printed as ``epoch + 1``).
        model: called as ``model(bg, atom_feats)``.
        data_loader: yields ``(smiles, bg, labels, masks)``.
        loss_criterion: element-wise loss (called without reduction here).
        optimizer: optimizer over ``model``'s parameters.
    """
    model.train()
    meter = Meter()
    for step, (smiles, bg, labels, masks) in enumerate(data_loader):
        feats = bg.ndata.pop(args['atom_data_field'])
        device = args['device']
        feats = feats.to(device)
        labels = labels.to(device)
        masks = masks.to(device)

        logits = model(bg, feats)
        # Zero the loss of missing labels; they still count in the mean.
        elementwise = loss_criterion(logits, labels)
        loss = (elementwise * (masks != 0).float()).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print('epoch {:d}/{:d}, batch {:d}/{:d}, loss {:.4f}'.format(
            epoch + 1, args['num_epochs'], step + 1, len(data_loader),
            loss.item()))
        meter.update(logits, labels, masks)

    score = np.mean(meter.compute_metric(args['metric_name']))
    print('epoch {:d}/{:d}, training {} {:.4f}'.format(
        epoch + 1, args['num_epochs'], args['metric_name'], score))
コード例 #3
0
    def _iterate(self, epoch, phase):
        """Run one pass over ``self.dataloaders[phase]`` and record metrics.

        In the ``'train'`` phase, gradients are accumulated over
        ``self.accumlation_steps`` batches before each optimizer step. Each
        batch loss is divided by that count, so an accumulated group
        contributes the mean of its batch gradients. A trailing partial group
        is also applied at the end of the pass. Without that, its gradients
        were silently discarded by the ``zero_grad()`` at the start of the next
        call.

        The other phases only run the forward pass. This method does not
        disable autograd itself (presumably the caller wraps validation in
        ``torch.no_grad()`` — confirm).

        Side effects: appends the epoch loss, dice and IoU to
        ``self.losses[phase]``, ``self.dice_scores[phase]`` and
        ``self.iou_scores[phase]``. Also calls ``visualize`` on the last batch.

        Returns:
            Mean over batches of the un-scaled batch loss. An empty loader
            raises ``ZeroDivisionError``.
        """
        meter = Meter(phase, epoch)
        start = time.strftime('%H:%M:%S')
        print("Starting epoch: {} | phase: {} | Time: {}".format(
            epoch + 1, phase, start))
        dl = self.dataloaders[phase]
        running_loss = 0.0
        total_steps = len(dl)
        self.optimizer.zero_grad()
        for itr, sample in enumerate(tqdm(dl)):
            images = sample['image']
            targets = sample['mask']
            loss, outputs = self._forward(images, targets)
            # Scale so an accumulated group's gradient is the group mean.
            loss /= self.accumlation_steps
            if phase == 'train':
                loss.backward()
                if (itr + 1) % self.accumlation_steps == 0:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
            running_loss += loss.item()
            outputs = outputs.detach().cpu()
            meter.update(targets, outputs)
        # Apply the gradients of a trailing partial accumulation group instead
        # of letting the next pass's zero_grad() throw them away.
        if phase == 'train' and total_steps % self.accumlation_steps != 0:
            self.optimizer.step()
            self.optimizer.zero_grad()
        # Undo the per-batch scaling: mean of the un-scaled batch losses.
        epoch_loss = (running_loss * self.accumlation_steps) / total_steps
        dice, iou = epoch_log(phase, epoch, epoch_loss, meter, start)
        # ``sample`` / ``outputs`` are the last batch of the loop above.
        visualize(sample, outputs, epoch, phase)

        self.losses[phase].append(epoch_loss)
        self.dice_scores[phase].append(dice)
        self.iou_scores[phase].append(iou)

        return epoch_loss
コード例 #4
0
def run_an_eval_epoch(args, model, data_loader):
    """Evaluate ``model`` without gradients and return the task-averaged ROC-AUC.

    Atom features are popped from ``bg.ndata[args['atom_data_field']]``.
    Features and labels are moved to ``args['device']``. The mask is passed
    to the meter as-is (it is not moved to the device).

    Args:
        args: dict; reads ``'atom_data_field'`` and ``'device'``.
        model: called as ``model(bg, atom_feats)``.
        data_loader: yields ``(smiles, bg, labels, mask)``.

    Returns:
        ``Meter.roc_auc_averaged_over_tasks()`` over all batches.
    """
    model.eval()
    meter = Meter()
    with torch.no_grad():
        for _smiles, bg, labels, mask in data_loader:
            feats = bg.ndata.pop(args['atom_data_field'])
            feats = feats.to(args['device'])
            labels = labels.to(args['device'])
            meter.update(model(bg, feats), labels, mask)
    return meter.roc_auc_averaged_over_tasks()
コード例 #5
0
ファイル: regression.py プロジェクト: zergey/dgl
def run_an_eval_epoch(args, model, data_loader):
    """Evaluate a regression ``model`` and return its averaged metric.

    Predictions come from the helper ``regress(args, model, bg)``. Only the
    labels are moved to ``args['device']``; masks are passed through
    unchanged.

    Args:
        args: dict; reads ``'device'`` and ``'metric_name'`` here (``regress``
            may read more).
        model: model forwarded to ``regress``.
        data_loader: yields ``(smiles, bg, labels, masks)``.

    Returns:
        ``np.mean`` of ``Meter.compute_metric(args['metric_name'])``.
    """
    model.eval()
    meter = Meter()
    with torch.no_grad():
        for batch in data_loader:
            _, bg, labels, masks = batch
            labels = labels.to(args['device'])
            meter.update(regress(args, model, bg), labels, masks)
        scores = meter.compute_metric(args['metric_name'])
        total_score = np.mean(scores)
    return total_score
コード例 #6
0
def run_an_eval_epoch(args, model, data_loader):
    """Evaluate ``model`` without gradients and return its averaged metric.

    Atom features are popped from ``bg.ndata[args['atom_data_field']]``.
    Features and labels are moved to ``args['device']``; masks are passed to
    the meter unchanged.

    Args:
        args: dict; reads ``'atom_data_field'``, ``'device'`` and
            ``'metric_name'``.
        model: called as ``model(bg, atom_feats)``.
        data_loader: yields ``(smiles, bg, labels, masks)``.

    Returns:
        ``np.mean`` of ``Meter.compute_metric(args['metric_name'])``.
    """
    model.eval()
    meter = Meter()
    with torch.no_grad():
        for _, bg, labels, masks in data_loader:
            feats = bg.ndata.pop(args['atom_data_field'])
            feats = feats.to(args['device'])
            labels = labels.to(args['device'])
            meter.update(model(bg, feats), labels, masks)
    scores = meter.compute_metric(args['metric_name'])
    return np.mean(scores)
コード例 #7
0
    def __init__(self,
                 data_loader,
                 model,
                 optimizer,
                 loss_fn,
                 debug=False,
                 cuda=False,
                 checkpoint_dir='checkpoints',
                 best_model_filename='best_model.pt'):
        """Store the training components and activate the epoch-0 settings.

        ``data_loader`` and ``loss_fn`` are kept as ``_data_loader`` /
        ``_loss_fn``. The active ``data_loader`` / ``loss_fn`` start as None
        and are then chosen by ``self._set_epoch(0)`` (presumably per-epoch
        schedules keyed by epoch number — confirm in ``_set_epoch``).

        Args:
            data_loader: data-loader schedule (see above).
            model: the network to train.
            optimizer: optimizer over ``model``'s parameters.
            loss_fn: loss-function schedule (see above).
            debug: stored as ``self.debug``.
            cuda: stored as ``self.cuda``.
            checkpoint_dir: directory for the best-model checkpoint.
            best_model_filename: file name of that checkpoint.
        """
        # Run configuration and checkpoint bookkeeping.
        self.debug = debug
        self.cuda = cuda
        self.checkpoint_dir = checkpoint_dir
        self.best_model_filename = best_model_filename
        self.best_iou = 0

        # Core training objects.
        self.model = model
        self.optimizer = optimizer

        # Schedules, plus the currently active entries (set by _set_epoch).
        self._data_loader = data_loader
        self._loss_fn = loss_fn
        self.data_loader = None
        self.loss_fn = None

        # Logging and running-metric trackers.
        self.visualizer = SegmentationVisualizer()
        self.train_loss_meter = Meter('Loss/train')
        self.train_iou_meter = Meter('IoU/train')
        self.val_loss_meter = Meter('Loss/val')
        self.val_iou_meter = Meter('IoU/val')

        self._set_epoch(0)
コード例 #8
0
def run_an_eval_epoch(args, model, data_loader):
    """Evaluate a binding-affinity ``model`` and return a dict of metrics.

    The meter is built with ``args['train_mean']`` / ``args['train_std']``
    (presumably to de-normalize predictions — confirm in ``Meter``). Both the
    labels and the batched graph are moved to ``args['device']``.

    Args:
        args: dict; reads ``'train_mean'``, ``'train_std'``, ``'device'`` and
            ``'metrics'``.
        model: called as ``model(bg)``.
        data_loader: yields ``(indices, ligand_mols, protein_mols, bg, labels)``.

    Returns:
        ``{metric_name: Meter.compute_metric(metric_name)}`` for every name in
        ``args['metrics']``.
    """
    model.eval()
    meter = Meter(args['train_mean'], args['train_std'])
    with torch.no_grad():
        for _indices, _ligands, _proteins, bg, labels in data_loader:
            device = args['device']
            labels = labels.to(device)
            bg = bg.to(device)
            meter.update(model(bg), labels)
    return {name: meter.compute_metric(name) for name in args['metrics']}
コード例 #9
0
def evaluate(data_loader):
    """Evaluate the module-level ``model`` with the module-level ``criterion``.

    Inputs and masks are moved to CUDA unconditionally. Predictions and
    targets are copied back to CPU for the ``Meter``. The CUDA cache is
    emptied at the end.

    Args:
        data_loader: yields ``(img, segm)`` pairs.

    Returns:
        ``(mean loss per batch, iou, dice, dice_neg, dice_pos)``. An empty
        loader raises ``ZeroDivisionError``.
    """
    meter = Meter('eval', 0)
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for img, segm in data_loader:
            img, segm = img.cuda(), segm.cuda()
            outputs = model(img)
            loss = criterion(outputs, segm)
            outputs_cpu = outputs.detach().cpu()
            segm_cpu = segm.detach().cpu()
            meter.update(segm_cpu, outputs_cpu)
            total_loss += loss.item()
        (dice, dice_neg, dice_pos), iou = meter.get_metrics()
        torch.cuda.empty_cache()

        return total_loss / len(data_loader), iou, dice, dice_neg, dice_pos
コード例 #10
0
def run_a_train_epoch(args, epoch, model, data_loader, loss_criterion,
                      optimizer):
    """Run one masked regression training epoch and print the averaged metric.

    Predictions come from ``regress(args, model, bg)``. Loss terms whose mask
    entry is 0 are zeroed before averaging.

    Args:
        args: dict; reads ``'device'``, ``'metric_name'`` and
            ``'num_epochs'`` here (``regress`` may read more).
        epoch: zero-based epoch index (printed as ``epoch + 1``).
        model: model forwarded to ``regress``.
        data_loader: yields ``(smiles, bg, labels, masks)``.
        loss_criterion: element-wise loss (called without reduction here).
        optimizer: optimizer over ``model``'s parameters.
    """
    model.train()
    meter = Meter()
    for _smiles, bg, labels, masks in data_loader:
        device = args['device']
        labels = labels.to(device)
        masks = masks.to(device)

        prediction = regress(args, model, bg)
        # Missing labels (mask == 0) contribute zero, but still count in the mean.
        elementwise = loss_criterion(prediction, labels)
        loss = (elementwise * (masks != 0).float()).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        meter.update(prediction, labels, masks)

    score = np.mean(meter.compute_metric(args['metric_name']))
    print('epoch {:d}/{:d}, training {} {:.4f}'.format(
        epoch + 1, args['num_epochs'], args['metric_name'], score))
コード例 #11
0
def run_a_train_epoch(args, epoch, model, data_loader, loss_criterion, optimizer):
    """Run one masked multi-label training epoch and print its ROC-AUC.

    Atom features are popped from ``bg.ndata[args['atom_data_field']]``.
    Loss terms whose mask entry is 0 (missing labels) are zeroed before
    averaging. The loss of every batch is printed.

    Args:
        args: dict; reads ``'atom_data_field'``, ``'device'`` and
            ``'num_epochs'``.
        epoch: zero-based epoch index (printed as ``epoch + 1``).
        model: called as ``model(bg, atom_feats)``.
        data_loader: yields ``(smiles, bg, labels, mask)``.
        loss_criterion: element-wise loss (called without reduction here).
        optimizer: optimizer over ``model``'s parameters.
    """
    model.train()
    meter = Meter()
    for batch_no, (smiles, bg, labels, mask) in enumerate(data_loader, start=1):
        feats = bg.ndata.pop(args['atom_data_field'])
        feats = feats.to(args['device'])
        labels = labels.to(args['device'])
        mask = mask.to(args['device'])

        logits = model(bg, feats)
        # Zero the loss of missing labels; they still count in the mean.
        elementwise = loss_criterion(logits, labels)
        loss = (elementwise * (mask != 0).float()).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print('epoch {:d}/{:d}, batch {:d}/{:d}, loss {:.4f}'.format(
            epoch + 1, args['num_epochs'], batch_no, len(data_loader),
            loss.item()))
        meter.update(logits, labels, mask)

    roc_auc = meter.roc_auc_averaged_over_tasks()
    print('epoch {:d}/{:d}, training roc-auc score {:.4f}'.format(
        epoch + 1, args['num_epochs'], roc_auc))
コード例 #12
0
ファイル: gan_train.py プロジェクト: abhikesh113/speech2voice
# Iterator over face batches that never runs out.
# NOTE(review): if `cycle` is itertools.cycle, it replays the batches saved from
# its first pass verbatim (no reshuffling on later passes) — confirm which
# `cycle` is imported here.
face_iterator = iter(cycle(face_loader))

# networks, Fe, Fg, Fd (f+d), Fc (f+c)
# Each call returns (network, optimizer). 'e' is built with train=False
# (presumably a frozen, pre-trained voice embedder — confirm in get_network).
# The others are trained.
print('Initializing networks...')
e_net, e_optimizer = get_network('e', NETWORKS_PARAMETERS, train=False)
g_net, g_optimizer = get_network('g', NETWORKS_PARAMETERS, train=True)
f_net, f_optimizer = get_network('f', NETWORKS_PARAMETERS, train=True)
d_net, d_optimizer = get_network('d', NETWORKS_PARAMETERS, train=True)
c_net, c_optimizer = get_network('c', NETWORKS_PARAMETERS, train=True)

# label for real/fake faces
# Shape (batch_size, 1): assumes every batch has exactly batch_size samples
# (e.g. drop_last=True) — confirm.
# NOTE(review): with an integer fill value, recent PyTorch versions infer an
# integer dtype; if these feed a float loss such as BCE, pass dtype=torch.float.
real_label = torch.full((DATASET_PARAMETERS['batch_size'], 1), 1)
fake_label = torch.full((DATASET_PARAMETERS['batch_size'], 1), 0)

# Meters for recording the training status
# Arguments look like (display name, aggregation 'sum'/'avg', format spec) —
# confirm against Meter.
iteration = Meter('Iter', 'sum', ':5d')
data_time = Meter('Data', 'sum', ':4.2f')
batch_time = Meter('Time', 'sum', ':4.2f')
D_real = Meter('D_real', 'avg', ':3.2f')
D_fake = Meter('D_fake', 'avg', ':3.2f')
C_real = Meter('C_real', 'avg', ':3.2f')
GD_fake = Meter('G_D_fake', 'avg', ':3.2f')
GC_fake = Meter('G_C_fake', 'avg', ':3.2f')

print('Training models...')
# Training runs for a fixed number of iterations rather than epochs.
for it in range(50000):
    # data
    start_time = time.time()

    # Pull the next voice and face batches.
    voice, voice_label = next(voice_iterator)
    face, face_label = next(face_iterator)
コード例 #13
0
    else:
        # Unknown model type: fail before loading any weights.
        raise ValueError('Model type is not correct: `{}`.'.format(
            MODEL["mode"]))

    # Move the model to the evaluation device and switch to inference mode.
    device = torch.device(EVAL["device"])
    model.to(device)
    model.eval()
    # map_location returns each storage unchanged, i.e. the checkpoint is
    # deserialized onto CPU. load_state_dict then copies it into the model.
    state = torch.load(EVAL["model_path"],
                       map_location=lambda storage, loc: storage)
    model.load_state_dict(state["state_dict"])

    # Optional test-time-augmentation wrapper (predictions merged by mean).
    # TTAModel is only bound when apply_tta is set, so later uses must be
    # guarded by the same flag.
    if EVAL["apply_tta"]:
        TTAModel = TTAWrapper(model, merge_mode="mean")

    # Metrics are only collected outside test mode; `meter` is unbound in
    # test mode.
    if not EVAL["test_mode"]:
        meter = Meter(base_threshold=EVAL["base_threshold"],
                      get_class_metric=True)

    # Output directory for the images: deleted and recreated on every run.
    images_path = EVAL["eval_images_path"] if not EVAL["test_mode"] else EVAL[
        "test_images_path"]
    # NOTE(review): the bare `except:` swallows every error (even
    # KeyboardInterrupt); presumably only FileNotFoundError is intended.
    # os.mkdir also fails if the parent directory is missing.
    try:
        shutil.rmtree(images_path)
    except:
        pass
    os.mkdir(images_path)

    start = time.time()
    for batch in tqdm(dataloader):
        images, targets, image_id = batch

        images = images.to(device)
        if EVAL["apply_tta"]:
コード例 #14
0
def train_model(model,
                train_loader, dev_loader,
                optimizer, criterion,
                num_classes, target_classes,
                label_encoder,
                device):
    """Train a sequence tagger, saving the model whenever dev macro-F1 improves.

    Reads the module-level ``args`` (``epochs``, ``no_crf``, ``save_path``).

    Loss computation:
        - With ``args.no_crf``, the loss is ``criterion`` on logits flattened
          to ``(-1, num_classes)`` and flattened labels.
        - Otherwise ``criterion`` is called with ``reduction="token_mean"``
          and ``mask=crf_mask``, and the result is negated (presumably a CRF
          layer returning a log-likelihood — confirm).

    On improvement, the whole model is saved to ``<save_path>/model.pt`` and
    ``label_encoder`` is pickled to ``<save_path>/label_encoder.pk``.

    Args:
        model: tagger called as ``model.forward(x, mask)``.
        train_loader: yields ``(x, y, mask, crf_mask)`` batches.
        dev_loader: yields ``(x, y, mask, crf_mask)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss (see above).
        num_classes: used to flatten logits to ``(-1, num_classes)``.
        target_classes: passed to ``Meter``.
        label_encoder: object pickled next to the best checkpoint.
        device: device the logits/labels are moved to for the loss.
    """
    # Create two Meter instances to track performance during training and evaluation.
    train_meter = Meter(target_classes)
    dev_meter = Meter(target_classes)

    best_f1 = -1

    for epoch in range(args.epochs):
        model.train()

        # ---- training loop ----
        train_tqdm = tqdm(train_loader)
        for train_x, train_y, mask, crf_mask in train_tqdm:
            optimizer.zero_grad()

            logits = model.forward(train_x, mask)

            if args.no_crf:
                loss = criterion(logits.reshape(-1, num_classes).to(device),
                                 train_y.reshape(-1).to(device))
            else:
                loss = -criterion(logits.to(device), train_y,
                                  reduction="token_mean", mask=crf_mask)

            loss.backward()
            optimizer.step()

            # Current metrics (running average over the epoch so far).
            loss, _, _, micro_f1, _, _, macro_f1 = train_meter.update_params(
                loss.item(), logits, train_y)

            train_tqdm.set_description("Epoch: {}/{}, Train Loss: {:.4f}, Train Micro F1: {:.4f}, Train Macro F1: {:.4f}".
                                       format(epoch + 1, args.epochs, loss, micro_f1, macro_f1))
            train_tqdm.refresh()

        # Reset the metrics for the next epoch.
        train_meter.reset()

        # ---- evaluation loop: same as training, without parameter updates ----
        model.eval()
        # Create the dev progress bar only now, so it does not render while
        # training is still running.
        dev_tqdm = tqdm(dev_loader)
        # Sentinel: with an empty dev loader, the stale *training* macro-F1
        # must not be mistaken for a dev score and trigger a save.
        macro_f1 = -1
        # No gradients are needed for evaluation; avoid building the graph.
        with torch.no_grad():
            for dev_x, dev_y, mask, crf_mask in dev_tqdm:
                logits = model.forward(dev_x, mask)

                if args.no_crf:
                    loss = criterion(logits.reshape(-1, num_classes).to(device),
                                     dev_y.reshape(-1).to(device))
                else:
                    loss = -criterion(logits.to(device), dev_y,
                                      reduction="token_mean", mask=crf_mask)

                loss, _, _, micro_f1, _, _, macro_f1 = dev_meter.update_params(
                    loss.item(), logits, dev_y)

                dev_tqdm.set_description("Dev Loss: {:.4f}, Dev Micro F1: {:.4f}, Dev Macro F1: {:.4f}".
                                         format(loss, micro_f1, macro_f1))
                dev_tqdm.refresh()

        dev_meter.reset()

        # If this epoch's dev macro-F1 is the best so far, save the model.
        if macro_f1 > best_f1:
            os.makedirs(args.save_path, exist_ok=True)

            print("Macro F1 score improved from {:.4f} -> {:.4f}. Saving model...".format(best_f1, macro_f1))

            best_f1 = macro_f1
            torch.save(model, os.path.join(args.save_path, "model.pt"))
            with open(os.path.join(args.save_path, "label_encoder.pk"), "wb") as file:
                pickle.dump(label_encoder, file)
コード例 #15
0
ファイル: classification.py プロジェクト: tmacmilan/dgl
def _eval_roc_auc(model, data_loader, atom_data_field, device):
    """Return the task-averaged ROC-AUC of ``model`` on ``data_loader`` (no autograd)."""
    meter = Meter()
    model.eval()
    with torch.no_grad():
        for smiles, bg, labels, mask in data_loader:
            atom_feats = bg.ndata.pop(atom_data_field)
            atom_feats, labels = atom_feats.to(device), labels.to(device)
            logits = model(atom_feats, bg)
            # The mask stays on its original device, as in the training loop's meter usage.
            meter.update(logits, labels, mask)
    return meter.roc_auc_averaged_over_tasks()


def main(args):
    """Train (or load) a GCN classifier on Tox21 and report ROC-AUC scores.

    Behavior:
        - With ``args.pre_trained``, the ``'GCN_Tox21'`` model-zoo weights are
          loaded and ``num_epochs`` is set to 0. The training loop is then
          skipped and only the test score is printed.
        - Otherwise training runs with early stopping on validation ROC-AUC.

    Args:
        args: namespace; reads ``pre_trained``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 128
    learning_rate = 0.001
    num_epochs = 100
    set_random_seed()

    # Interchangeable with other Dataset
    dataset = Tox21()
    atom_data_field = 'h'

    trainset, valset, testset = split_dataset(dataset, [0.8, 0.1, 0.1])
    train_loader = DataLoader(
        trainset, batch_size=batch_size, collate_fn=collate_molgraphs)
    val_loader = DataLoader(
        valset, batch_size=batch_size, collate_fn=collate_molgraphs)
    test_loader = DataLoader(
        testset, batch_size=batch_size, collate_fn=collate_molgraphs)

    if args.pre_trained:
        # No training: loss_criterion / optimizer / stopper are never needed.
        num_epochs = 0
        model = model_zoo.chem.load_pretrained('GCN_Tox21')
    else:
        # Interchangeable with other models
        model = model_zoo.chem.GCNClassifier(in_feats=74,
                                             gcn_hidden_feats=[64, 64],
                                             n_tasks=dataset.n_tasks)
        # reduction='none' keeps per-element losses so missing labels can be masked.
        loss_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(
            dataset.task_pos_weights).to(device), reduction='none')
        optimizer = Adam(model.parameters(), lr=learning_rate)
        stopper = EarlyStopping(patience=10)
    model.to(device)

    for epoch in range(num_epochs):
        model.train()
        print('Start training')
        train_meter = Meter()
        for batch_id, batch_data in enumerate(train_loader):
            smiles, bg, labels, mask = batch_data
            atom_feats = bg.ndata.pop(atom_data_field)
            atom_feats, labels, mask = atom_feats.to(device), labels.to(device), mask.to(device)
            logits = model(atom_feats, bg)
            # Mask non-existing labels
            loss = (loss_criterion(logits, labels)
                    * (mask != 0).float()).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('epoch {:d}/{:d}, batch {:d}/{:d}, loss {:.4f}'.format(
                epoch + 1, num_epochs, batch_id + 1, len(train_loader), loss.item()))
            train_meter.update(logits, labels, mask)
        train_roc_auc = train_meter.roc_auc_averaged_over_tasks()
        print('epoch {:d}/{:d}, training roc-auc score {:.4f}'.format(
            epoch + 1, num_epochs, train_roc_auc))

        val_roc_auc = _eval_roc_auc(model, val_loader, atom_data_field, device)
        early_stop = stopper.step(val_roc_auc, model)
        # Report before breaking, so the epoch that triggers early stopping is logged too.
        print('epoch {:d}/{:d}, validation roc-auc score {:.4f}, best validation roc-auc score {:.4f}'.format(
            epoch + 1, num_epochs, val_roc_auc, stopper.best_score))
        if early_stop:
            break

    # NOTE(review): this evaluates the last-epoch weights, not necessarily the
    # best checkpoint tracked by the stopper — confirm whether it should be restored.
    test_roc_auc = _eval_roc_auc(model, test_loader, atom_data_field, device)
    print('test roc-auc score {:.4f}'.format(test_roc_auc))
コード例 #16
0
def train_model(model, train_loader, dev_loader, optimizer, criterion,
                num_classes, target_classes, it, label_encoder, device):
    """Fine-tune a token classifier with patience-based early stopping on dev macro-F1.

    Reads the module-level ``args`` (``epochs``, ``fine_tune``, ``save_path``,
    ``patience``). A linear LR schedule without warmup is created; it is
    stepped only when ``args.fine_tune`` is set.

    On improvement, the whole model is saved to
    ``<save_path>/model_{it + 1}.pt`` and ``label_encoder`` is pickled to
    ``<save_path>/label_encoder.pk``. Training stops once more than
    ``args.patience`` consecutive epochs pass without improvement.

    Args:
        model: classifier called as ``model.forward(x, mask)``.
        train_loader: yields ``(x, y, mask)`` batches.
        dev_loader: yields ``(x, y, mask)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss on logits flattened to ``(-1, num_classes)``.
        num_classes: number of output classes.
        target_classes: passed to ``Meter``.
        it: run index, used in the checkpoint file name.
        label_encoder: object pickled next to the best checkpoint.
        device: device the logits/labels are moved to for the loss.

    Returns:
        The best dev macro-F1 reached.
    """
    # Create two Meter instances to track performance during training and evaluation.
    train_meter = Meter(target_classes)
    dev_meter = Meter(target_classes)

    best_f1 = 0
    loss, macro_f1 = 0, 0

    total_steps = len(train_loader) * args.epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,  # Default value in run_glue.py
        num_training_steps=total_steps)

    curr_patience = 0

    for epoch in range(args.epochs):
        train_tqdm = tqdm(train_loader, leave=False)

        model.train()

        # ---- training loop ----
        for train_x, train_y, mask in train_tqdm:
            # The bar shows the metrics from the previous step.
            train_tqdm.set_description(
                "    Training - Epoch: {}/{}, Loss: {:.4f}, F1: {:.4f}, Best F1: {:.4f}"
                .format(epoch + 1, args.epochs, loss, macro_f1, best_f1))
            train_tqdm.refresh()

            optimizer.zero_grad()

            logits = model.forward(train_x, mask)

            loss = criterion(
                logits.reshape(-1, num_classes).to(device),
                train_y.reshape(-1).to(device))
            loss.backward()
            optimizer.step()

            if args.fine_tune:
                scheduler.step()

            # Current metrics (running average over the epoch so far).
            loss, _, _, _, _, _, macro_f1 = train_meter.update_params(
                loss.item(), logits, train_y)

        # Reset the metrics for the next epoch.
        train_meter.reset()

        dev_tqdm = tqdm(dev_loader, leave=False)
        model.eval()
        loss, macro_f1 = 0, 0

        # ---- evaluation loop: no parameter updates, so skip building the graph ----
        with torch.no_grad():
            for dev_x, dev_y, mask in dev_tqdm:
                dev_tqdm.set_description(
                    "    Evaluating - Epoch: {}/{}, Loss: {:.4f}, F1: {:.4f}, Best F1: {:.4f}"
                    .format(epoch + 1, args.epochs, loss, macro_f1, best_f1))
                dev_tqdm.refresh()

                logits = model.forward(dev_x, mask)
                loss = criterion(
                    logits.reshape(-1, num_classes).to(device),
                    dev_y.reshape(-1).to(device))

                loss, _, _, _, _, _, macro_f1 = dev_meter.update_params(
                    loss.item(), logits, dev_y)

        dev_meter.reset()

        # If the current macro-F1 is the best one, save the model.
        if macro_f1 > best_f1:
            curr_patience = 0
            best_f1 = macro_f1
            # torch.save/open do not create missing directories.
            os.makedirs(args.save_path, exist_ok=True)
            torch.save(
                model,
                os.path.join(args.save_path, "model_{}.pt".format(it + 1)))
            with open(os.path.join(args.save_path, "label_encoder.pk"),
                      "wb") as file:
                pickle.dump(label_encoder, file)
        else:
            curr_patience += 1

        if curr_patience > args.patience:
            break

    return best_f1
コード例 #17
0
# Second 'c' network and its optimizer (built from the 'c' config, trainable).
c2_net, c2_optimizer = get_network('c', NETWORKS_PARAMETERS, train=True)

# Resume training: load the previously saved models.
if NETWORKS_PARAMETERS['finetune']:
    restore_train(g_net, d1_net, f1_net, f2_net)

# label for real/fake faces
# NOTE(review): with an integer fill value, recent PyTorch versions infer an
# integer dtype; if these feed a float loss such as BCE, pass dtype=torch.float.
real_label = torch.full((DATASET_PARAMETERS['batch_size'], 1), 1)
fake_label = torch.full((DATASET_PARAMETERS['batch_size'], 1), 0)
# +1 / -1 scalar tensors; presumably passed to backward() to flip the gradient
# sign of the discriminator's real/fake terms — confirm at the call sites.
D_loss_positive = torch.tensor(1, dtype=torch.float)
D_loss_negative = D_loss_positive * -1

#  Meters for recording the training status (logging module) #
# TensorBoard writer plus a file logger whose name is the start timestamp.
writer = SummaryWriter("./models/log")
logger = Logger(DATASET_PARAMETERS['log_dir'], time.strftime("%Y-%m-%d,%H,%M"))
# Meter args look like (display name, aggregation 'sum'/'avg', format spec) — confirm.
iteration = Meter('Iter', 'sum', ':5d')
data_time = Meter('Data', 'sum', ':4.2f')
batch_time = Meter('Time', 'sum', ':4.2f')
D_real = Meter('D_real', 'avg', ':4.3f')
D_fake = Meter('D_fake', 'avg', ':4.3f')
C1_real = Meter('C1_real', 'avg', ':4.3f')
C2_real = Meter('C2_real', 'avg', ':4.3f')
C1_fake = Meter('C1_fake', 'avg', ':4.3f')
C2_fake = Meter('C2_fake', 'avg', ':4.3f')
GD_fake = Meter('G_D_fake', 'avg', ':4.3f')

print('Training models...')
# Fixed iteration budget (90000 inclusive) rather than epochs.
for it in range(90000 + 1):
    # data
    start_time = time.time()
    # Each voice batch carries both an identity label and an emotion label.
    voice, voice_identity_label, voice_emotion_label = next(voice_iterator)
コード例 #18
0
class Trainer(object):
    """Train and validate a binary-segmentation model with per-epoch schedules.

    ``data_loader`` and ``loss_fn`` are mappings keyed by epoch number. At the
    start of each epoch, ``_set_epoch`` switches to the entry for that epoch,
    if one exists. Each data-loader entry is indexed with ``'train'`` and
    ``'val'``. The model state_dict with the best mean validation IoU is
    saved to ``checkpoint_dir/best_model_filename``.
    """

    def __init__(self,
                 data_loader,
                 model,
                 optimizer,
                 loss_fn,
                 debug=False,
                 cuda=False,
                 checkpoint_dir='checkpoints',
                 best_model_filename='best_model.pt'):
        """Store the components and activate the epoch-0 loader/loss."""
        self._data_loader = data_loader
        self._loss_fn = loss_fn

        # Currently active entries; selected by _set_epoch.
        self.data_loader = None
        self.loss_fn = None

        self.model = model
        self.optimizer = optimizer

        self.visualizer = SegmentationVisualizer()

        self.train_loss_meter = Meter('Loss/train')
        self.train_iou_meter = Meter('IoU/train')

        self.val_loss_meter = Meter('Loss/val')
        self.val_iou_meter = Meter('IoU/val')

        self.checkpoint_dir = checkpoint_dir
        self.best_model_filename = best_model_filename

        self.debug = debug
        self.cuda = cuda

        self.best_iou = 0

        self._set_epoch(0)

    def _set_epoch(self, epoch):
        """Switch to the data loaders / loss function scheduled for ``epoch``, if any."""
        if epoch in self._data_loader:
            print('Switching data loaders')
            self.data_loader = self._data_loader[epoch]
        if epoch in self._loss_fn:
            print('Switching loss function')
            self.loss_fn = self._loss_fn[epoch]

    def _use_cuda(self):
        """Return True when CUDA was requested and a CUDA device is available.

        ``torch.cuda.is_available()`` is used rather than ``is_initialized()``.
        The latter stays False until CUDA is lazily initialized, so the old
        check never moved anything to the GPU in a fresh process.
        """
        return bool(self.cuda and torch.cuda.is_available())

    def _move_to_device(self):
        """Move the model and the active loss function to the GPU when CUDA is in use.

        Called by both train and validate: ``_set_epoch`` may have just
        swapped in a new loss function, and ``validate`` may run before any
        training.
        """
        if self._use_cuda():
            self.model = self.model.cuda()
            self.loss_fn = self.loss_fn.cuda()

    def train_one_epoch(self, epoch):
        """Train for one epoch, logging running loss/IoU every 100 batches."""
        self._set_epoch(epoch)
        self._move_to_device()
        use_cuda = self._use_cuda()

        self.model.train()
        self.train_loss_meter.reset()
        self.train_iou_meter.reset()

        train_loader = self.data_loader['train']
        for i, (src, dst) in enumerate(tqdm(train_loader, leave=False)):
            if use_cuda:
                dst = dst.cuda(non_blocking=True)
                src = src.cuda(non_blocking=True)

            self.optimizer.zero_grad()

            y_head = self.model(src)

            loss = self.loss_fn(y_head, dst)

            loss.backward()
            self.optimizer.step()

            self.train_loss_meter(loss.item())

            # Outputs > 0 count as foreground (presumably logits, i.e. p > 0.5).
            self.train_iou_meter(
                iou_binary((y_head.detach() > 0), dst.detach()))

            if i % 100 == 0:
                step = epoch * len(train_loader) + i
                data = {
                    'loss': self.train_loss_meter.value(),
                    'accuracy': self.train_iou_meter.value()
                }
                self.visualizer.add_scalars(data, step, prefix='train_')
                if self.debug and i == 0:
                    images = {
                        'images': src,
                        'gt_masks': dst,
                        'masks': y_head.detach() > 0
                    }
                    self.visualizer.add_images(images, epoch, prefix='train_')
        print(
            f'\tFinal {self.train_loss_meter.name}:\t{self.train_loss_meter.mean():.4f}\t',
            f'final {self.train_iou_meter.name}:\t{self.train_iou_meter.mean():.4f}'
        )

    def validate(self, epoch):
        """Evaluate on the validation split, log mean loss/IoU and checkpoint on improvement."""
        self._set_epoch(epoch)
        self._move_to_device()
        use_cuda = self._use_cuda()

        self.model.eval()
        self.val_loss_meter.reset()
        self.val_iou_meter.reset()
        for i, (src, dst) in enumerate(tqdm(self.data_loader['val'],
                                            leave=False)):
            if use_cuda:
                dst = dst.cuda(non_blocking=True)
                src = src.cuda(non_blocking=True)

            # Forward pass and loss need no autograd graph during validation.
            with torch.no_grad():
                y_head = self.model(src)
                loss = self.loss_fn(y_head, dst)
            self.val_loss_meter(loss.item())

            self.val_iou_meter(iou_binary(y_head.detach() > 0, dst.detach()))
            if self.debug and i == 0 and epoch % 50 == 0:
                images = {
                    'images': src,
                    'gt_masks': dst,
                    'masks': y_head.detach() > 0
                }
                self.visualizer.add_images(images, epoch, prefix='val_')

        data = {
            'loss': self.val_loss_meter.mean(),
            'accuracy': self.val_iou_meter.mean()
        }
        self.visualizer.add_scalars(data, epoch, prefix='val_')
        print(
            f'\tFinal {self.val_loss_meter.name}:\t\t{self.val_loss_meter.mean():.4f}\t',
            f'final {self.val_iou_meter.name}:\t\t{self.val_iou_meter.mean():.4f}'
        )

        self.save_best_model()

    @property
    def best_model_checkpoint_filepath(self):
        """Path of the best-model checkpoint inside ``checkpoint_dir``."""
        return osp.join(self.checkpoint_dir, self.best_model_filename)

    def load_previous_best_model(self):
        """Load the saved best state_dict (deserialized onto CPU) into the model."""
        device = torch.device('cpu')
        state_dict = torch.load(self.best_model_checkpoint_filepath,
                                map_location=device)
        self.model.load_state_dict(state_dict)

    def save_best_model(self):
        """Save the model's state_dict if the mean validation IoU beats the best so far."""
        import os

        current_iou = self.val_iou_meter.mean()
        if current_iou > self.best_iou:
            print(
                f'Updating best model @{self.val_iou_meter.name}:{current_iou:.04f}'
            )
            self.best_iou = current_iou
            # torch.save does not create missing directories.
            os.makedirs(self.checkpoint_dir, exist_ok=True)
            torch.save(self.model.state_dict(),
                       self.best_model_checkpoint_filepath)