Ejemplo n.º 1
0
def train(render_count=-1):
    """Run two-head clustering training, evaluating and checkpointing per epoch.

    Reads all settings from the module-level ``config`` object (and the
    module-level ``net_name`` / ``opt_name`` checkpoint file names when
    ``config.restart`` is set). Each epoch trains both heads in the order given
    by ``config.head_A_first`` using ``IID_loss``, then calls ``cluster_eval``,
    redraws ``plots.png`` and writes checkpoints / config files into
    ``config.out_dir``.

    Args:
        render_count: Forwarded unchanged to ``save_progress`` when
            ``config.save_progression`` is enabled; its meaning is defined by
            that helper.

    Returns:
        None. The process is terminated with ``exit(1)`` if a batch loss is
        not finite, and with ``exit(0)`` after the first epoch when
        ``config.test_code`` is set.
    """
    dataloaders_head_A, dataloaders_head_B, \
    mapping_assignment_dataloader, mapping_test_dataloader = \
      cluster_twohead_create_dataloaders(config)

    # Architecture is looked up by name in the archs module namespace.
    net = archs.__dict__[config.arch](config)
    if config.restart:
        model_path = os.path.join(config.out_dir, net_name)
        # map_location keeps the loaded tensors on CPU; the model is moved to
        # GPU below.
        net.load_state_dict(
            torch.load(model_path, map_location=lambda storage, loc: storage))

    net.cuda()
    net = torch.nn.DataParallel(net)
    net.train()

    optimiser = get_opt(config.opt)(net.module.parameters(), lr=config.lr)
    if config.restart:
        print("loading latest opt")
        optimiser.load_state_dict(
            torch.load(os.path.join(config.out_dir, opt_name)))

    # Order in which the two heads are trained within each epoch.
    heads = ["B", "A"]
    if config.head_A_first:
        heads = ["A", "B"]

    # Number of passes over the data per head, per outer epoch.
    head_epochs = {}
    head_epochs["A"] = config.head_A_epochs
    head_epochs["B"] = config.head_B_epochs

    # Results
    # ----------------------------------------------------------------------

    if config.restart:
        if not config.restart_from_best:
            next_epoch = config.last_epoch + 1  # corresponds to last saved model
        else:
            # resume right after the epoch with the highest recorded accuracy
            next_epoch = np.argmax(np.array(config.epoch_acc)) + 1
        print("starting from epoch %d" % next_epoch)

        # in case we overshot without saving
        # Accuracy/stat histories are cut to next_epoch entries, loss histories
        # to next_epoch - 1 (the fresh-run branch below evaluates once before
        # any training, so the accuracy lists are one entry longer).
        config.epoch_acc = config.epoch_acc[:next_epoch]  # in case we overshot
        config.epoch_avg_subhead_acc = config.epoch_avg_subhead_acc[:
                                                                    next_epoch]
        config.epoch_stats = config.epoch_stats[:next_epoch]

        if config.double_eval:
            config.double_eval_acc = config.double_eval_acc[:next_epoch]
            config.double_eval_avg_subhead_acc = config.double_eval_avg_subhead_acc[:
                                                                                    next_epoch]
            config.double_eval_stats = config.double_eval_stats[:next_epoch]

        config.epoch_loss_head_A = config.epoch_loss_head_A[:(next_epoch - 1)]
        config.epoch_loss_no_lamb_head_A = config.epoch_loss_no_lamb_head_A[:(
            next_epoch - 1)]

        config.epoch_loss_head_B = config.epoch_loss_head_B[:(next_epoch - 1)]
        config.epoch_loss_no_lamb_head_B = config.epoch_loss_no_lamb_head_B[:(
            next_epoch - 1)]
    else:
        # Fresh run: start with empty histories.
        config.epoch_acc = []
        config.epoch_avg_subhead_acc = []
        config.epoch_stats = []

        if config.double_eval:
            config.double_eval_acc = []
            config.double_eval_avg_subhead_acc = []
            config.double_eval_stats = []

        config.epoch_loss_head_A = []
        config.epoch_loss_no_lamb_head_A = []

        config.epoch_loss_head_B = []
        config.epoch_loss_no_lamb_head_B = []

        # Evaluate the untrained network once before training starts.
        sub_head = None
        if config.select_sub_head_on_loss:
            sub_head = get_subhead_using_loss(config,
                                              dataloaders_head_B,
                                              net,
                                              sobel=False,
                                              lamb=config.lamb_B)
        # NOTE(review): cluster_eval is presumably what appends to
        # config.epoch_stats (and the double_eval lists) - the [-1] reads just
        # below rely on that; confirm against its definition.
        _ = cluster_eval(
            config,
            net,
            mapping_assignment_dataloader=mapping_assignment_dataloader,
            mapping_test_dataloader=mapping_test_dataloader,
            sobel=False,
            use_sub_head=sub_head)

        print("Pre: time %s: \n %s" %
              (datetime.now(), nice(config.epoch_stats[-1])))
        if config.double_eval:
            print("double eval: \n %s" % (nice(config.double_eval_stats[-1])))
        sys.stdout.flush()
        next_epoch = 1

    # Six plots (two accuracy, four loss), plus two more when double_eval is on.
    fig, axarr = plt.subplots(6 + 2 * int(config.double_eval),
                              sharex=False,
                              figsize=(20, 20))

    # save_progression is optional in config; absent is treated as disabled.
    save_progression = hasattr(config, "save_progression") and \
                       config.save_progression
    if save_progression:
        save_progression_count = 0
        save_progress(config,
                      net,
                      mapping_assignment_dataloader,
                      mapping_test_dataloader,
                      save_progression_count,
                      sobel=False,
                      render_count=render_count)
        save_progression_count += 1

    # Train
    # ------------------------------------------------------------------------

    for e_i in xrange(next_epoch, config.num_epochs):
        print("Starting e_i: %d" % e_i)

        if e_i in config.lr_schedule:
            optimiser = update_lr(optimiser, lr_mult=config.lr_mult)

        for head_i in range(2):
            head = heads[head_i]
            # epoch_loss / epoch_loss_no_lamb alias the lists stored on config,
            # so the appends at the end of this loop update config in place.
            if head == "A":
                dataloaders = dataloaders_head_A
                epoch_loss = config.epoch_loss_head_A
                epoch_loss_no_lamb = config.epoch_loss_no_lamb_head_A
                lamb = config.lamb_A
            elif head == "B":
                dataloaders = dataloaders_head_B
                epoch_loss = config.epoch_loss_head_B
                epoch_loss_no_lamb = config.epoch_loss_no_lamb_head_B
                lamb = config.lamb_B

            avg_loss = 0.  # over heads and head_epochs (and sub_heads)
            avg_loss_no_lamb = 0.
            avg_loss_count = 0

            for head_i_epoch in range(head_epochs[head]):
                sys.stdout.flush()

                # All dataloaders are stepped in lockstep by izip below.
                iterators = (d for d in dataloaders)

                b_i = 0
                for tup in itertools.izip(*iterators):
                    net.module.zero_grad()

                    # Buffers sized for a full batch; trimmed further down when
                    # the current batch is smaller.
                    all_imgs = torch.zeros(
                        (config.batch_sz, config.in_channels, config.input_sz,
                         config.input_sz)).cuda()
                    all_imgs_tf = torch.zeros(
                        (config.batch_sz, config.in_channels, config.input_sz,
                         config.input_sz)).cuda()

                    imgs_curr = tup[0][0]  # always the first
                    curr_batch_sz = imgs_curr.size(0)
                    # Pair the first loader's images with each of the other
                    # loaders' images, one slice of the buffers per loader.
                    for d_i in xrange(config.num_dataloaders):
                        imgs_tf_curr = tup[1 + d_i][0]  # from 2nd to last
                        assert (curr_batch_sz == imgs_tf_curr.size(0))

                        actual_batch_start = d_i * curr_batch_sz
                        actual_batch_end = actual_batch_start + curr_batch_sz
                        all_imgs[actual_batch_start:actual_batch_end, :, :, :] = \
                          imgs_curr.cuda()
                        all_imgs_tf[actual_batch_start:actual_batch_end, :, :, :] = \
                          imgs_tf_curr.cuda()

                    if not (curr_batch_sz == config.dataloader_batch_sz):
                        print("last batch sz %d" % curr_batch_sz)

                    # one copy of the batch per paired dataloader
                    curr_total_batch_sz = curr_batch_sz * config.num_dataloaders
                    # drop the unused tail of the pre-allocated buffers
                    all_imgs = all_imgs[:curr_total_batch_sz, :, :, :]
                    all_imgs_tf = all_imgs_tf[:curr_total_batch_sz, :, :, :]

                    # The outputs are indexed per sub-head below.
                    x_outs = net(all_imgs)
                    x_tf_outs = net(all_imgs_tf)

                    avg_loss_batch = None  # avg over the heads
                    avg_loss_no_lamb_batch = None
                    for i in xrange(config.num_sub_heads):
                        loss, loss_no_lamb = IID_loss(x_outs[i],
                                                      x_tf_outs[i],
                                                      lamb=lamb)
                        if avg_loss_batch is None:
                            avg_loss_batch = loss
                            avg_loss_no_lamb_batch = loss_no_lamb
                        else:
                            avg_loss_batch += loss
                            avg_loss_no_lamb_batch += loss_no_lamb

                    avg_loss_batch /= config.num_sub_heads
                    avg_loss_no_lamb_batch /= config.num_sub_heads

                    # Log every 100th batch, and every batch of the first epoch
                    # of this run.
                    if ((b_i % 100) == 0) or (e_i == next_epoch):
                        print(
                          "Model ind %d epoch %d head %s batch: %d avg loss %f avg loss no "
                          "lamb %f time %s" % \
                          (config.model_ind, e_i, head, b_i, avg_loss_batch.item(),
                           avg_loss_no_lamb_batch.item(), datetime.now()))
                        sys.stdout.flush()

                    # Abort the whole process on NaN/inf loss.
                    if not np.isfinite(avg_loss_batch.item()):
                        print("Loss is not finite... %s:" %
                              avg_loss_batch.item())
                        exit(1)

                    avg_loss += avg_loss_batch.item()
                    avg_loss_no_lamb += avg_loss_no_lamb_batch.item()
                    avg_loss_count += 1

                    avg_loss_batch.backward()
                    optimiser.step()

                    if ((b_i % 50) == 0) and save_progression:
                        save_progress(config,
                                      net,
                                      mapping_assignment_dataloader,
                                      mapping_test_dataloader,
                                      save_progression_count,
                                      sobel=False,
                                      render_count=render_count)
                        save_progression_count += 1

                    b_i += 1
                    # test mode: stop after two batches
                    if b_i == 2 and config.test_code:
                        break

            # NOTE(review): raises ZeroDivisionError if no batch was processed
            # for this head (e.g. head_epochs[head] == 0 or empty dataloaders).
            avg_loss = float(avg_loss / avg_loss_count)
            avg_loss_no_lamb = float(avg_loss_no_lamb / avg_loss_count)

            epoch_loss.append(avg_loss)
            epoch_loss_no_lamb.append(avg_loss_no_lamb)

        # Eval
        # -----------------------------------------------------------------------

        sub_head = None
        if config.select_sub_head_on_loss:
            sub_head = get_subhead_using_loss(config,
                                              dataloaders_head_B,
                                              net,
                                              sobel=False,
                                              lamb=config.lamb_B)
        is_best = cluster_eval(
            config,
            net,
            mapping_assignment_dataloader=mapping_assignment_dataloader,
            mapping_test_dataloader=mapping_test_dataloader,
            sobel=False,
            use_sub_head=sub_head)

        # NOTE(review): the label says "Pre" although this runs after every
        # training epoch; same message as the pre-training evaluation above.
        print("Pre: time %s: \n %s" %
              (datetime.now(), nice(config.epoch_stats[-1])))
        if config.double_eval:
            print("double eval: \n %s" % (nice(config.double_eval_stats[-1])))
        sys.stdout.flush()

        # Redraw every subplot from the full histories and overwrite plots.png.
        axarr[0].clear()
        axarr[0].plot(config.epoch_acc)
        axarr[0].set_title("acc (best), top: %f" % max(config.epoch_acc))

        axarr[1].clear()
        axarr[1].plot(config.epoch_avg_subhead_acc)
        axarr[1].set_title("acc (avg), top: %f" %
                           max(config.epoch_avg_subhead_acc))

        axarr[2].clear()
        axarr[2].plot(config.epoch_loss_head_A)
        axarr[2].set_title("Loss head A")

        axarr[3].clear()
        axarr[3].plot(config.epoch_loss_no_lamb_head_A)
        axarr[3].set_title("Loss no lamb head A")

        axarr[4].clear()
        axarr[4].plot(config.epoch_loss_head_B)
        axarr[4].set_title("Loss head B")

        axarr[5].clear()
        axarr[5].plot(config.epoch_loss_no_lamb_head_B)
        axarr[5].set_title("Loss no lamb head B")

        if config.double_eval:
            axarr[6].clear()
            axarr[6].plot(config.double_eval_acc)
            axarr[6].set_title("double eval acc (best), top: %f" %
                               max(config.double_eval_acc))

            axarr[7].clear()
            axarr[7].plot(config.double_eval_avg_subhead_acc)
            axarr[7].set_title("double eval acc (avg)), top: %f" %
                               max(config.double_eval_avg_subhead_acc))

        fig.tight_layout()
        fig.canvas.draw_idle()
        fig.savefig(os.path.join(config.out_dir, "plots.png"))

        if is_best or (e_i % config.save_freq == 0):
            # Weights are moved to CPU for saving and back to GPU afterwards.
            net.module.cpu()

            if e_i % config.save_freq == 0:
                torch.save(net.module.state_dict(),
                           os.path.join(config.out_dir, "latest_net.pytorch"))
                torch.save(
                    optimiser.state_dict(),
                    os.path.join(config.out_dir, "latest_optimiser.pytorch"))

                config.last_epoch = e_i  # for last saved version

            if is_best:
                # also serves as backup if hardware fails - less likely to hit this
                torch.save(net.module.state_dict(),
                           os.path.join(config.out_dir, "best_net.pytorch"))
                torch.save(
                    optimiser.state_dict(),
                    os.path.join(config.out_dir, "best_optimiser.pytorch"))

                with open(os.path.join(config.out_dir, "best_config.pickle"),
                          'wb') as outfile:
                    pickle.dump(config, outfile)

                with open(os.path.join(config.out_dir, "best_config.txt"),
                          "w") as text_file:
                    text_file.write("%s" % config)

            net.module.cuda()

        # The running config (including histories) is rewritten every epoch,
        # whether or not a checkpoint was saved.
        with open(os.path.join(config.out_dir, "config.pickle"),
                  'wb') as outfile:
            pickle.dump(config, outfile)

        with open(os.path.join(config.out_dir, "config.txt"),
                  "w") as text_file:
            text_file.write("%s" % config)

        # test mode: stop after a single epoch
        if config.test_code:
            exit(0)
Ejemplo n.º 2
0
def main():
    """Train a supervised head on top of a previously trained feature network.

    Parses command-line options, optionally resumes a run from the
    ``config.pickle`` and ``latest_*.pytorch`` files in the run's output
    directory, builds STL10 train/test loaders from the old model's config,
    wraps the old feature network in ``SupHead5`` and trains trunk and head
    with separate Adam optimisers under a cross-entropy loss. After every
    epoch the test accuracy is measured, ``plots.png`` is refreshed,
    checkpoints are written when due, and the config is saved as both pickle
    and text.

    Returns:
        None.

    Raises:
        AssertionError: if the reloaded configs do not match the requested
            model indices, the old model's dataset is not ``"STL10"``,
            ``--arch`` is not ``"SupHead5"``, or ``--restart_new_model_ind``
            is given without ``--restart``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_ind", type=int, required=True)

    parser.add_argument("--arch", type=str, required=True)

    parser.add_argument("--head_lr", type=float, required=True)
    parser.add_argument("--trunk_lr", type=float, required=True)

    parser.add_argument("--num_epochs", type=int, default=3200)

    parser.add_argument("--new_batch_sz", type=int, default=-1)

    parser.add_argument("--old_model_ind", type=int, required=True)

    parser.add_argument("--penultimate_features",
                        default=False,
                        action="store_true")

    parser.add_argument("--random_affine", default=False, action="store_true")
    parser.add_argument("--affine_p", type=float, default=0.5)

    parser.add_argument("--cutout", default=False, action="store_true")
    parser.add_argument("--cutout_p", type=float, default=0.5)
    parser.add_argument("--cutout_max_box", type=float, default=0.5)

    parser.add_argument("--restart", default=False, action="store_true")
    parser.add_argument("--lr_schedule", type=int, nargs="+", default=[])
    parser.add_argument("--lr_mult", type=float, default=0.5)

    parser.add_argument("--restart_new_model_ind",
                        default=False,
                        action="store_true")
    parser.add_argument("--new_model_ind", type=int, default=0)

    parser.add_argument("--out_root",
                        type=str,
                        default="/scratch/shared/slow/xuji/iid_private")
    config = parser.parse_args()  # new config

    # Setup ----------------------------------------------------------------------

    config.contiguous_sz = 10  # Tencrop
    config.out_dir = os.path.join(config.out_root, str(config.model_ind))

    if not os.path.exists(config.out_dir):
        os.makedirs(config.out_dir)

    if config.restart:
        given_config = config
        reloaded_config_path = os.path.join(given_config.out_dir,
                                            "config.pickle")
        print("Loading restarting config from: %s" % reloaded_config_path)
        # NOTE: unpickling can execute arbitrary code; only restart from an
        # out_root whose contents are trusted.
        with open(reloaded_config_path, "rb") as config_f:
            config = pickle.load(config_f)
        assert (config.model_ind == given_config.model_ind)

        # The stored config wins, except for the options copied over here.
        config.restart = True
        config.num_epochs = given_config.num_epochs  # train for longer

        config.restart_new_model_ind = given_config.restart_new_model_ind
        config.new_model_ind = given_config.new_model_ind

        start_epoch = config.last_epoch + 1

        print("...restarting from epoch %d" % start_epoch)

        # in case we overshot without saving
        config.epoch_acc = config.epoch_acc[:start_epoch]
        config.epoch_loss = config.epoch_loss[:start_epoch]

    else:
        config.epoch_acc = []
        config.epoch_loss = []
        start_epoch = 0

    # old config only used retrospectively for setting up model at start
    reloaded_config_path = os.path.join(
        os.path.join(config.out_root, str(config.old_model_ind)),
        "config.pickle")
    print("Loading old features config from: %s" % reloaded_config_path)
    # NOTE: same trust requirement as above applies to the old model's pickle.
    with open(reloaded_config_path, "rb") as config_f:
        old_config = pickle.load(config_f)
        assert (old_config.model_ind == config.old_model_ind)

    # -1 means "reuse the batch size of the old model".
    if config.new_batch_sz == -1:
        config.new_batch_sz = old_config.batch_sz

    fig, axarr = plt.subplots(2, sharex=False, figsize=(20, 20))

    # Data -----------------------------------------------------------------------

    assert (old_config.dataset == "STL10")

    # make supervised data: train on train, test on test, unlabelled is unused
    tf1, tf2, tf3 = sobel_make_transforms(old_config,
                                          random_affine=config.random_affine,
                                          cutout=config.cutout,
                                          cutout_p=config.cutout_p,
                                          cutout_max_box=config.cutout_max_box,
                                          affine_p=config.affine_p)

    dataset_class = torchvision.datasets.STL10
    train_data = dataset_class(
        root=old_config.dataset_root,
        transform=tf2,  # also could use tf1
        split="train")

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.new_batch_sz,
                                               shuffle=True,
                                               num_workers=0,
                                               drop_last=False)

    test_data = dataset_class(root=old_config.dataset_root,
                              transform=None,
                              split="test")
    test_data = TenCropAndFinish(test_data,
                                 input_sz=old_config.input_sz,
                                 include_rgb=old_config.include_rgb)

    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=config.new_batch_sz,
        # full batch
        shuffle=False,
        num_workers=0,
        drop_last=False)

    # Model ----------------------------------------------------------------------

    net_features = archs.__dict__[old_config.arch](old_config)

    # On a fresh run the trunk starts from the old model's best weights; on
    # restart the whole SupHead5 state is loaded further down instead.
    if not config.restart:
        model_path = os.path.join(old_config.out_dir, "best_net.pytorch")
        net_features.load_state_dict(
            torch.load(model_path, map_location=lambda storage, loc: storage))

    dlen = get_dlen(net_features,
                    train_loader,
                    include_rgb=old_config.include_rgb,
                    penultimate_features=config.penultimate_features)
    print("dlen: %d" % dlen)

    assert (config.arch == "SupHead5")
    net = SupHead5(net_features, dlen=dlen, gt_k=old_config.gt_k)

    if config.restart:
        print("restarting from latest net")
        model_path = os.path.join(config.out_dir, "latest_net.pytorch")
        net.load_state_dict(
            torch.load(model_path, map_location=lambda storage, loc: storage))

    net.cuda()
    net = torch.nn.DataParallel(net)

    # Trunk and head get independent optimisers so they can use different
    # learning rates.
    opt_trunk = torch.optim.Adam(net.module.trunk.parameters(),
                                 lr=config.trunk_lr)
    opt_head = torch.optim.Adam(net.module.head.parameters(),
                                lr=(config.head_lr))

    if config.restart:
        print("restarting from latest optimiser")
        optimiser_states = torch.load(
            os.path.join(config.out_dir, "latest_optimiser.pytorch"))
        opt_trunk.load_state_dict(optimiser_states["opt_trunk"])
        opt_head.load_state_dict(optimiser_states["opt_head"])
    else:
        print("using new optimiser state")

    criterion = nn.CrossEntropyLoss().cuda()

    # Baseline accuracy before any supervised training (fresh runs only); it
    # becomes the first entry of config.epoch_acc.
    if not config.restart:
        net.eval()
        acc = assess_acc_block(
            net,
            test_loader,
            gt_k=old_config.gt_k,
            include_rgb=old_config.include_rgb,
            penultimate_features=config.penultimate_features,
            contiguous_sz=config.contiguous_sz)

        print("pre: model %d old model %d, acc %f time %s" %
              (config.model_ind, config.old_model_ind, acc, datetime.now()))
        sys.stdout.flush()

        config.epoch_acc.append(acc)

    # Optionally continue the restarted run under a new model index, so that
    # all further outputs go to a different directory.
    if config.restart_new_model_ind:
        assert (config.restart)
        config.model_ind = config.new_model_ind  # old_model_ind stays same
        config.out_dir = os.path.join(config.out_root, str(config.model_ind))
        print("restarting as model %d" % config.model_ind)

        if not os.path.exists(config.out_dir):
            os.makedirs(config.out_dir)

    # Train ----------------------------------------------------------------------

    for e_i in xrange(start_epoch, config.num_epochs):
        net.train()

        if e_i in config.lr_schedule:
            print("e_i %d, multiplying lr for opt trunk and head by %f" %
                  (e_i, config.lr_mult))
            opt_trunk = update_lr(opt_trunk, lr_mult=config.lr_mult)
            opt_head = update_lr(opt_head, lr_mult=config.lr_mult)
            # lr_changes may be missing from configs pickled by older runs.
            if not hasattr(config, "lr_changes"):
                config.lr_changes = []
            config.lr_changes.append((e_i, config.lr_mult))

        avg_loss = 0.
        num_batches = len(train_loader)
        for i, (imgs, targets) in enumerate(train_loader):
            imgs = sobel_process(imgs.cuda(), old_config.include_rgb)
            targets = targets.cuda()

            x_out = net(imgs, penultimate_features=config.penultimate_features)
            loss = criterion(x_out, targets)

            avg_loss += float(loss.data)

            opt_trunk.zero_grad()
            opt_head.zero_grad()

            loss.backward()

            opt_trunk.step()
            opt_head.step()

            # Log every 100th batch, and every batch of the first epoch of
            # this run.
            if (i % 100 == 0) or (e_i == start_epoch):
                print("batch %d of %d, loss %f, time %s" %
                      (i, num_batches, float(loss.data), datetime.now()))
                sys.stdout.flush()

        avg_loss /= num_batches

        net.eval()
        acc = assess_acc_block(
            net,
            test_loader,
            gt_k=old_config.gt_k,
            include_rgb=old_config.include_rgb,
            penultimate_features=config.penultimate_features,
            contiguous_sz=config.contiguous_sz)

        print(
            "model %d old model %d epoch %d acc %f time %s" %
            (config.model_ind, config.old_model_ind, e_i, acc, datetime.now()))
        sys.stdout.flush()

        # Compare against the history before appending the new value.
        is_best = False
        if acc > max(config.epoch_acc):
            is_best = True

        config.epoch_acc.append(acc)
        config.epoch_loss.append(avg_loss)

        axarr[0].clear()
        axarr[0].plot(config.epoch_acc)
        axarr[0].set_title("Acc")

        axarr[1].clear()
        axarr[1].plot(config.epoch_loss)
        axarr[1].set_title("Loss")

        fig.canvas.draw_idle()
        fig.savefig(os.path.join(config.out_dir, "plots.png"))

        if is_best or (e_i % 10 == 0):
            # Weights are moved to CPU for saving and back to GPU afterwards.
            net.module.cpu()

            if is_best:
                torch.save(net.module.state_dict(),
                           os.path.join(config.out_dir, "best_net.pytorch"))
                torch.save(
                    {
                        "opt_head": opt_head.state_dict(),
                        "opt_trunk": opt_trunk.state_dict()
                    }, os.path.join(config.out_dir, "best_optimiser.pytorch"))

            # save model sparingly for this script
            if e_i % 10 == 0:
                torch.save(net.module.state_dict(),
                           os.path.join(config.out_dir, "latest_net.pytorch"))
                torch.save(
                    {
                        "opt_head": opt_head.state_dict(),
                        "opt_trunk": opt_trunk.state_dict()
                    }, os.path.join(config.out_dir,
                                    "latest_optimiser.pytorch"))

                # A restart reloads only the "latest" files, so last_epoch
                # must track them; updating it on a best-only save would make
                # a restart resume from the wrong epoch with older weights.
                config.last_epoch = e_i  # for last saved version

            net.module.cuda()

        # Pickle data is bytes: write in binary mode, matching the "rb" mode
        # used when this file is reloaded on restart.
        with open(os.path.join(config.out_dir, "config.pickle"),
                  'wb') as outfile:
            pickle.dump(config, outfile)

        with open(os.path.join(config.out_dir, "config.txt"),
                  "w") as text_file:
            text_file.write("%s" % config)
Ejemplo n.º 3
0
    next_epoch = 1

# Four stacked subplots on one figure; what is drawn on them is not visible
# in this excerpt.
fig, axarr = plt.subplots(4, sharex=False, figsize=(20, 20))

# Train ------------------------------------------------------------------------

# One iteration per epoch, starting at next_epoch (set before this excerpt).
for e_i in xrange(next_epoch, config.num_epochs):
    print("Starting e_i: %d" % e_i)
    sys.stdout.flush()

    # All dataloaders are stepped in lockstep by izip below.
    iterators = (d for d in dataloaders)

    b_i = 0

    # Adjust the learning rate at the scheduled epochs.
    if e_i in config.lr_schedule:
        optimiser = update_lr(optimiser, lr_mult=config.lr_mult)

    avg_loss = 0.  # over epoch
    avg_loss_no_lamb = 0.
    avg_loss_count = 0

    for tup in itertools.izip(*iterators):
        net.module.zero_grad()

        # Zero-filled GPU buffers sized for a full batch.
        # NOTE(review): presumably filled from `tup` and trimmed to the real
        # batch size further down - the rest of the loop body is not visible
        # in this excerpt.
        all_imgs = torch.zeros(config.batch_sz, config.in_channels,
                               config.input_sz, config.input_sz).cuda()
        all_imgs_tf = torch.zeros(config.batch_sz, config.in_channels,
                                  config.input_sz, config.input_sz).cuda()

        imgs_curr = tup[0][0]  # always the first
        curr_batch_sz = imgs_curr.size(0)
Ejemplo n.º 4
0
def train():
  """Train the two-head segmentation network, evaluating after every epoch.

  All settings are read from the module-level ``config``. Per-epoch losses
  are appended to lists stored on it, and it is re-written to
  ``config.out_dir`` (pickle and text) after every epoch so that a run can be
  resumed with ``config.restart``.

  Each epoch makes one pass over the dataloaders per head (in the order given
  by ``heads``), taking one optimiser step per batch on the loss averaged
  over the sub-heads. It then calls ``segmentation_eval``, redraws the plots
  and saves checkpoints: "latest.pytorch" every ``config.save_freq`` epochs
  and "best.pytorch" whenever the evaluation reports a new best.

  The process is terminated with status 1 if a batch loss is not finite, and
  with status 0 after the first epoch when ``config.test_code`` is set.

  Note: a leftover debugging ``print``/``exit()`` pair at the top of this
  function used to stop the process before any work was done; it has been
  removed so that training actually runs.
  """
  dataloaders_head_A, mapping_assignment_dataloader, mapping_test_dataloader = \
    segmentation_create_dataloaders(config)
  dataloaders_head_B = dataloaders_head_A  # unlike for clustering datasets

  net = archs.__dict__[config.arch](config)
  checkpoint = None
  if config.restart:
    # Load onto the CPU first; the network is moved to the GPU just below.
    checkpoint = torch.load(os.path.join(config.out_dir, dict_name),
                            map_location=lambda storage, loc: storage)
    net.load_state_dict(checkpoint["net"])
  net.cuda()
  net = torch.nn.DataParallel(net)
  net.train()

  optimiser = get_opt(config.opt)(net.module.parameters(), lr=config.lr)
  if config.restart:
    optimiser.load_state_dict(checkpoint["optimiser"])

  heads = ["A", "B"]
  if hasattr(config, "head_B_first") and config.head_B_first:
    heads = ["B", "A"]

  # Results
  # ----------------------------------------------------------------------

  if config.restart:
    next_epoch = config.last_epoch + 1
    print("starting from epoch %d" % next_epoch)

    # Accuracy lists also hold the pre-training evaluation, so they are one
    # entry longer than the loss lists.
    config.epoch_acc = config.epoch_acc[:next_epoch]  # in case we overshot
    config.epoch_avg_subhead_acc = config.epoch_avg_subhead_acc[:next_epoch]
    config.epoch_stats = config.epoch_stats[:next_epoch]

    config.epoch_loss_head_A = config.epoch_loss_head_A[:(next_epoch - 1)]
    config.epoch_loss_no_lamb_head_A = config.epoch_loss_no_lamb_head_A[
                                       :(next_epoch - 1)]
    config.epoch_loss_head_B = config.epoch_loss_head_B[:(next_epoch - 1)]
    config.epoch_loss_no_lamb_head_B = config.epoch_loss_no_lamb_head_B[
                                       :(next_epoch - 1)]
  else:
    config.epoch_acc = []
    config.epoch_avg_subhead_acc = []
    config.epoch_stats = []

    config.epoch_loss_head_A = []
    config.epoch_loss_no_lamb_head_A = []

    config.epoch_loss_head_B = []
    config.epoch_loss_no_lamb_head_B = []

    # Evaluate the untrained network; segmentation_eval presumably appends to
    # config.epoch_stats, which is read right below - verify against it.
    _ = segmentation_eval(config, net,
                          mapping_assignment_dataloader=mapping_assignment_dataloader,
                          mapping_test_dataloader=mapping_test_dataloader,
                          sobel=(not config.no_sobel),
                          using_IR=config.using_IR)

    print(
      "Pre: time %s: \n %s" % (datetime.now(), nice(config.epoch_stats[-1])))
    sys.stdout.flush()
    next_epoch = 1

  fig, axarr = plt.subplots(6, sharex=False, figsize=(20, 20))

  if not config.use_uncollapsed_loss:
    print("using condensed loss (default)")
    loss_fn = IID_segmentation_loss
  else:
    print("using uncollapsed loss!")
    loss_fn = IID_segmentation_loss_uncollapsed

  # The raw image buffers hold one channel fewer than config.in_channels when
  # Sobel preprocessing is enabled (sobel_process runs before the network).
  # This does not change between batches, so it is computed once here.
  if not config.no_sobel:
    pre_channels = config.in_channels - 1
  else:
    pre_channels = config.in_channels

  # Train
  # ------------------------------------------------------------------------

  for e_i in xrange(next_epoch, config.num_epochs):
    print("Starting e_i: %d %s" % (e_i, datetime.now()))
    sys.stdout.flush()

    if e_i in config.lr_schedule:
      optimiser = update_lr(optimiser, lr_mult=config.lr_mult)

    for head in heads:
      if head == "A":
        dataloaders = dataloaders_head_A
        epoch_loss = config.epoch_loss_head_A
        epoch_loss_no_lamb = config.epoch_loss_no_lamb_head_A
        lamb = config.lamb_A

      elif head == "B":
        dataloaders = dataloaders_head_B
        epoch_loss = config.epoch_loss_head_B
        epoch_loss_no_lamb = config.epoch_loss_no_lamb_head_B
        lamb = config.lamb_B

      iterators = (d for d in dataloaders)
      b_i = 0
      avg_loss = 0.  # over heads and head_epochs (and sub_heads)
      avg_loss_no_lamb = 0.
      avg_loss_count = 0

      # One batch is drawn from every dataloader in lockstep.
      for tup in itertools.izip(*iterators):
        net.module.zero_grad()

        # Full-size buffers; trimmed below if the batch is smaller than usual.
        all_img1 = torch.zeros(config.batch_sz, pre_channels,
                               config.input_sz, config.input_sz).to(
          torch.float32).cuda()
        all_img2 = torch.zeros(config.batch_sz, pre_channels,
                               config.input_sz, config.input_sz).to(
          torch.float32).cuda()
        all_affine2_to_1 = torch.zeros(config.batch_sz, 2, 3).to(
          torch.float32).cuda()
        all_mask_img1 = torch.zeros(config.batch_sz, config.input_sz,
                                    config.input_sz).to(torch.float32).cuda()

        # Stack the batches of all dataloaders one after another.
        curr_batch_sz = tup[0][0].shape[0]
        for d_i in xrange(config.num_dataloaders):
          img1, img2, affine2_to_1, mask_img1 = tup[d_i]
          assert (img1.shape[0] == curr_batch_sz)

          actual_batch_start = d_i * curr_batch_sz
          actual_batch_end = actual_batch_start + curr_batch_sz

          all_img1[actual_batch_start:actual_batch_end, :, :, :] = img1
          all_img2[actual_batch_start:actual_batch_end, :, :, :] = img2
          all_affine2_to_1[actual_batch_start:actual_batch_end, :,
          :] = affine2_to_1
          all_mask_img1[actual_batch_start:actual_batch_end, :, :] = mask_img1

        if (curr_batch_sz != config.dataloader_batch_sz) and (
            e_i == next_epoch):
          print("last batch sz %d" % curr_batch_sz)

        # Drop the unused tail of the buffers for a short batch.
        curr_total_batch_sz = curr_batch_sz * config.num_dataloaders  # times 2
        all_img1 = all_img1[:curr_total_batch_sz, :, :, :]
        all_img2 = all_img2[:curr_total_batch_sz, :, :, :]
        all_affine2_to_1 = all_affine2_to_1[:curr_total_batch_sz, :, :]
        all_mask_img1 = all_mask_img1[:curr_total_batch_sz, :, :]

        if (not config.no_sobel):
          all_img1 = sobel_process(all_img1, config.include_rgb,
                                   using_IR=config.using_IR)
          all_img2 = sobel_process(all_img2, config.include_rgb,
                                   using_IR=config.using_IR)

        x1_outs = net(all_img1, head=head)
        x2_outs = net(all_img2, head=head)

        avg_loss_batch = None  # avg over the heads
        avg_loss_no_lamb_batch = None

        for i in xrange(config.num_sub_heads):
          loss, loss_no_lamb = loss_fn(x1_outs[i],
                                       x2_outs[i],
                                       all_affine2_to_1=all_affine2_to_1,
                                       all_mask_img1=all_mask_img1,
                                       lamb=lamb,
                                       half_T_side_dense=config.half_T_side_dense,
                                       half_T_side_sparse_min=config.half_T_side_sparse_min,
                                       half_T_side_sparse_max=config.half_T_side_sparse_max)

          if avg_loss_batch is None:
            avg_loss_batch = loss
            avg_loss_no_lamb_batch = loss_no_lamb
          else:
            avg_loss_batch += loss
            avg_loss_no_lamb_batch += loss_no_lamb

        avg_loss_batch /= config.num_sub_heads
        avg_loss_no_lamb_batch /= config.num_sub_heads

        # Log every 100th batch, and every batch of the first epoch run.
        if ((b_i % 100) == 0) or (e_i == next_epoch):
          print(
            "Model ind %d epoch %d head %s batch: %d avg loss %f avg loss no "
            "lamb %f "
            "time %s" % \
            (config.model_ind, e_i, head, b_i, avg_loss_batch.item(),
             avg_loss_no_lamb_batch.item(), datetime.now()))
          sys.stdout.flush()

        if not np.isfinite(avg_loss_batch.item()):
          print("Loss is not finite... %s:" % str(avg_loss_batch))
          # sys.exit rather than the site-provided exit(), which is meant for
          # interactive use and is not guaranteed to exist.
          sys.exit(1)

        avg_loss += avg_loss_batch.item()
        avg_loss_no_lamb += avg_loss_no_lamb_batch.item()
        avg_loss_count += 1

        avg_loss_batch.backward()
        optimiser.step()

        torch.cuda.empty_cache()

        b_i += 1
        if b_i == 2 and config.test_code:
          break

      avg_loss = float(avg_loss / avg_loss_count)
      avg_loss_no_lamb = float(avg_loss_no_lamb / avg_loss_count)

      epoch_loss.append(avg_loss)
      epoch_loss_no_lamb.append(avg_loss_no_lamb)

    # Eval
    # -----------------------------------------------------------------------

    is_best = segmentation_eval(config, net,
                                mapping_assignment_dataloader=mapping_assignment_dataloader,
                                mapping_test_dataloader=mapping_test_dataloader,
                                sobel=(
                                  not config.no_sobel),
                                using_IR=config.using_IR)

    print(
      "Pre: time %s: \n %s" % (datetime.now(), nice(config.epoch_stats[-1])))
    sys.stdout.flush()

    axarr[0].clear()
    axarr[0].plot(config.epoch_acc)
    axarr[0].set_title("acc (best), top: %f" % max(config.epoch_acc))

    axarr[1].clear()
    axarr[1].plot(config.epoch_avg_subhead_acc)
    axarr[1].set_title("acc (avg), top: %f" % max(config.epoch_avg_subhead_acc))

    axarr[2].clear()
    axarr[2].plot(config.epoch_loss_head_A)
    axarr[2].set_title("Loss head A")

    axarr[3].clear()
    axarr[3].plot(config.epoch_loss_no_lamb_head_A)
    axarr[3].set_title("Loss no lamb head A")

    axarr[4].clear()
    axarr[4].plot(config.epoch_loss_head_B)
    axarr[4].set_title("Loss head B")

    axarr[5].clear()
    axarr[5].plot(config.epoch_loss_no_lamb_head_B)
    axarr[5].set_title("Loss no lamb head B")

    fig.canvas.draw_idle()
    fig.savefig(os.path.join(config.out_dir, "plots.png"))

    if is_best or (e_i % config.save_freq == 0):
      # State dicts are taken with the module on the CPU, then it is moved
      # back to the GPU once the files are written.
      net.module.cpu()
      save_dict = {"net": net.module.state_dict(),
                   "optimiser": optimiser.state_dict()}

      if e_i % config.save_freq == 0:
        torch.save(save_dict, os.path.join(config.out_dir, "latest.pytorch"))
        config.last_epoch = e_i  # for last saved version

      if is_best:
        torch.save(save_dict, os.path.join(config.out_dir, "best.pytorch"))

        with open(os.path.join(config.out_dir, "best_config.pickle"),
                  'wb') as outfile:
          pickle.dump(config, outfile)

        with open(os.path.join(config.out_dir, "best_config.txt"),
                  "w") as text_file:
          text_file.write("%s" % config)

      net.module.cuda()

    with open(os.path.join(config.out_dir, "config.pickle"), 'wb') as outfile:
      pickle.dump(config, outfile)

    with open(os.path.join(config.out_dir, "config.txt"), "w") as text_file:
      text_file.write("%s" % config)

    if config.test_code:
      sys.exit(0)