Example #1
def result_log(config, net, mapping_assignment_dataloader, mapping_test_dataloader):
    """Logs accuracies, losses, other per epoch stats and setting the loss function to be used

    Params:
      config: configuration for the training run
      net: PyTorch network
      mapping_assignment_dataloader: TODO
      mapping_test_dataloader: TODO

    Returns:
        [type] -- [description]
    """

    if config.restart:
        next_epoch = config.last_epoch + 1
        print("starting from epoch %d" % next_epoch)

        config.epoch_acc = config.epoch_acc[:next_epoch]  # in case we overshot
        config.epoch_avg_subhead_acc = config.epoch_avg_subhead_acc[:next_epoch]
        config.epoch_stats = config.epoch_stats[:next_epoch]

        config.epoch_loss_head_A = config.epoch_loss_head_A[:(next_epoch - 1)]
        config.epoch_loss_no_lamb_head_A = config.epoch_loss_no_lamb_head_A[
            :(next_epoch - 1)]
        config.epoch_loss_head_B = config.epoch_loss_head_B[:(next_epoch - 1)]
        config.epoch_loss_no_lamb_head_B = config.epoch_loss_no_lamb_head_B[
            :(next_epoch - 1)]
    else:
        config.epoch_acc = []
        config.epoch_avg_subhead_acc = []
        config.epoch_stats = []

        config.epoch_loss_head_A = []
        config.epoch_loss_no_lamb_head_A = []

        config.epoch_loss_head_B = []
        config.epoch_loss_no_lamb_head_B = []

        _ = segmentation_eval(config, net,
                              mapping_assignment_dataloader=mapping_assignment_dataloader,
                              mapping_test_dataloader=mapping_test_dataloader,
                              sobel=(not config.no_sobel),
                              using_IR=config.using_IR)

        print(
            "Pre: time %s: \n %s" % (datetime.now(), nice(config.epoch_stats[-1])))
        sys.stdout.flush()
        next_epoch = 1

    fig, axarr = plt.subplots(6, sharex=False, figsize=(20, 20))

    if not config.use_uncollapsed_loss:
        print("using condensed loss (default)")
        loss_fn = IID_segmentation_loss
    else:
        print("using uncollapsed loss!")
        loss_fn = IID_segmentation_loss_uncollapsed

    return next_epoch, fig, axarr, loss_fn
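
A minimal usage sketch (hypothetical; it assumes a surrounding training script that has already built config, net and the two mapping dataloaders, as in the examples below). Note the off-by-one in the restart branch above: the accuracy/stats lists also hold the pre-training evaluation at index 0, so they are cut at next_epoch, while the per-epoch loss lists start at epoch 1 and are cut at next_epoch - 1.

# Hypothetical driver code; train_one_epoch is a placeholder for the
# per-epoch training loop shown in the later examples.
next_epoch, fig, axarr, loss_fn = result_log(config, net,
                                             mapping_assignment_dataloader,
                                             mapping_test_dataloader)
for e_i in range(next_epoch, config.num_epochs):
    train_one_epoch(config, net, optimiser, loss_fn, e_i)  # placeholder
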
Example #2
def train():
    dataloaders_head_A, mapping_assignment_dataloader, mapping_test_dataloader = \
      segmentation_create_dataloaders(config)
    dataloaders_head_B = dataloaders_head_A  # unlike for clustering datasets

    net = archs.__dict__[config.arch](config)
    if config.restart:
        dict = torch.load(os.path.join(config.out_dir, dict_name),
                          map_location=lambda storage, loc: storage)
        net.load_state_dict(dict["net"])
    net.cuda()
    net = torch.nn.DataParallel(net)
    net.train()

    optimiser = get_opt(config.opt)(net.module.parameters(), lr=config.lr)
    if config.restart:
        optimiser.load_state_dict(dict["optimiser"])

    heads = ["A", "B"]
    if hasattr(config, "head_B_first") and config.head_B_first:
        heads = ["B", "A"]

    # Results
    # ----------------------------------------------------------------------

    if config.restart:
        next_epoch = config.last_epoch + 1
        print("starting from epoch %d" % next_epoch)

        config.epoch_acc = config.epoch_acc[:next_epoch]  # in case we overshot
        config.epoch_avg_subhead_acc = config.epoch_avg_subhead_acc[:
                                                                    next_epoch]
        config.epoch_stats = config.epoch_stats[:next_epoch]

        config.epoch_loss_head_A = config.epoch_loss_head_A[:(next_epoch - 1)]
        config.epoch_loss_no_lamb_head_A = config.epoch_loss_no_lamb_head_A[:(
            next_epoch - 1)]
        config.epoch_loss_head_B = config.epoch_loss_head_B[:(next_epoch - 1)]
        config.epoch_loss_no_lamb_head_B = config.epoch_loss_no_lamb_head_B[:(
            next_epoch - 1)]
    else:
        config.epoch_acc = []
        config.epoch_avg_subhead_acc = []
        config.epoch_stats = []

        config.epoch_loss_head_A = []
        config.epoch_loss_no_lamb_head_A = []

        config.epoch_loss_head_B = []
        config.epoch_loss_no_lamb_head_B = []

        _ = segmentation_eval(
            config,
            net,
            mapping_assignment_dataloader=mapping_assignment_dataloader,
            mapping_test_dataloader=mapping_test_dataloader,
            sobel=(not config.no_sobel),
            using_IR=config.using_IR)

        print("Pre: time %s: \n %s" %
              (datetime.now(), nice(config.epoch_stats[-1])))
        sys.stdout.flush()
        next_epoch = 1

    fig, axarr = plt.subplots(6, sharex=False, figsize=(20, 20))

    if not config.use_uncollapsed_loss:
        print("using condensed loss (default)")
        loss_fn = IID_segmentation_loss
    else:
        print("using uncollapsed loss!")
        loss_fn = IID_segmentation_loss_uncollapsed

    # Train
    # ------------------------------------------------------------------------

    for e_i in xrange(next_epoch, config.num_epochs):
        print("Starting e_i: %d %s" % (e_i, datetime.now()))
        sys.stdout.flush()

        if e_i in config.lr_schedule:
            optimiser = update_lr(optimiser, lr_mult=config.lr_mult)

        for head_i in range(2):
            head = heads[head_i]
            if head == "A":
                dataloaders = dataloaders_head_A
                epoch_loss = config.epoch_loss_head_A
                epoch_loss_no_lamb = config.epoch_loss_no_lamb_head_A
                lamb = config.lamb_A

            elif head == "B":
                dataloaders = dataloaders_head_B
                epoch_loss = config.epoch_loss_head_B
                epoch_loss_no_lamb = config.epoch_loss_no_lamb_head_B
                lamb = config.lamb_B

            iterators = (d for d in dataloaders)
            b_i = 0
            avg_loss = 0.  # over heads and head_epochs (and subheads)
            avg_loss_no_lamb = 0.
            avg_loss_count = 0

            for tup in itertools.izip(*iterators):
                net.module.zero_grad()

                if not config.no_sobel:
                    pre_channels = config.in_channels - 1
                else:
                    pre_channels = config.in_channels

                all_img1 = torch.zeros(config.batch_sz, pre_channels,
                                       config.input_sz, config.input_sz).to(
                                           torch.float32).cuda()
                all_img2 = torch.zeros(config.batch_sz, pre_channels,
                                       config.input_sz, config.input_sz).to(
                                           torch.float32).cuda()
                all_affine2_to_1 = torch.zeros(config.batch_sz, 2,
                                               3).to(torch.float32).cuda()
                all_mask_img1 = torch.zeros(config.batch_sz, config.input_sz,
                                            config.input_sz).to(
                                                torch.float32).cuda()

                curr_batch_sz = tup[0][0].shape[0]
                for d_i in xrange(config.num_dataloaders):
                    img1, img2, affine2_to_1, mask_img1 = tup[d_i]
                    assert (img1.shape[0] == curr_batch_sz)

                    actual_batch_start = d_i * curr_batch_sz
                    actual_batch_end = actual_batch_start + curr_batch_sz

                    all_img1[
                        actual_batch_start:actual_batch_end, :, :, :] = img1
                    all_img2[
                        actual_batch_start:actual_batch_end, :, :, :] = img2
                    all_affine2_to_1[actual_batch_start:
                                     actual_batch_end, :, :] = affine2_to_1
                    all_mask_img1[
                        actual_batch_start:actual_batch_end, :, :] = mask_img1

                if not (curr_batch_sz
                        == config.dataloader_batch_sz) and (e_i == next_epoch):
                    print("last batch sz %d" % curr_batch_sz)

                curr_total_batch_sz = curr_batch_sz * config.num_dataloaders  # times 2
                all_img1 = all_img1[:curr_total_batch_sz, :, :, :]
                all_img2 = all_img2[:curr_total_batch_sz, :, :, :]
                all_affine2_to_1 = all_affine2_to_1[:curr_total_batch_sz, :, :]
                all_mask_img1 = all_mask_img1[:curr_total_batch_sz, :, :]

                if (not config.no_sobel):
                    all_img1 = sobel_process(all_img1,
                                             config.include_rgb,
                                             using_IR=config.using_IR)
                    all_img2 = sobel_process(all_img2,
                                             config.include_rgb,
                                             using_IR=config.using_IR)

                x1_outs = net(all_img1, head=head)
                x2_outs = net(all_img2, head=head)

                avg_loss_batch = None  # avg over the heads
                avg_loss_no_lamb_batch = None

                for i in xrange(config.num_subheads):
                    loss, loss_no_lamb = loss_fn(
                        x1_outs[i],
                        x2_outs[i],
                        all_affine2_to_1=all_affine2_to_1,
                        all_mask_img1=all_mask_img1,
                        lamb=lamb,
                        half_T_side_dense=config.half_T_side_dense,
                        half_T_side_sparse_min=config.half_T_side_sparse_min,
                        half_T_side_sparse_max=config.half_T_side_sparse_max)

                    if avg_loss_batch is None:
                        avg_loss_batch = loss
                        avg_loss_no_lamb_batch = loss_no_lamb
                    else:
                        avg_loss_batch += loss
                        avg_loss_no_lamb_batch += loss_no_lamb

                avg_loss_batch /= config.num_subheads
                avg_loss_no_lamb_batch /= config.num_subheads

                if ((b_i % 100) == 0) or (e_i == next_epoch):
                    print(
                      "Model ind %d epoch %d head %s batch: %d avg loss %f avg loss no "
                      "lamb %f "
                      "time %s" % \
                      (config.model_ind, e_i, head, b_i, avg_loss_batch.item(),
                       avg_loss_no_lamb_batch.item(), datetime.now()))
                    sys.stdout.flush()

                if not np.isfinite(avg_loss_batch.item()):
                    print("Loss is not finite... %s:" % str(avg_loss_batch))
                    exit(1)

                avg_loss += avg_loss_batch.item()
                avg_loss_no_lamb += avg_loss_no_lamb_batch.item()
                avg_loss_count += 1

                avg_loss_batch.backward()
                optimiser.step()

                torch.cuda.empty_cache()

                b_i += 1
                if b_i == 2 and config.test_code:
                    break

            avg_loss = float(avg_loss / avg_loss_count)
            avg_loss_no_lamb = float(avg_loss_no_lamb / avg_loss_count)

            epoch_loss.append(avg_loss)
            epoch_loss_no_lamb.append(avg_loss_no_lamb)

        # Eval
        # -----------------------------------------------------------------------

        is_best = segmentation_eval(
            config,
            net,
            mapping_assignment_dataloader=mapping_assignment_dataloader,
            mapping_test_dataloader=mapping_test_dataloader,
            sobel=(not config.no_sobel),
            using_IR=config.using_IR)

        print("Pre: time %s: \n %s" %
              (datetime.now(), nice(config.epoch_stats[-1])))
        sys.stdout.flush()

        axarr[0].clear()
        axarr[0].plot(config.epoch_acc)
        axarr[0].set_title("acc (best), top: %f" % max(config.epoch_acc))

        axarr[1].clear()
        axarr[1].plot(config.epoch_avg_subhead_acc)
        axarr[1].set_title("acc (avg), top: %f" %
                           max(config.epoch_avg_subhead_acc))

        axarr[2].clear()
        axarr[2].plot(config.epoch_loss_head_A)
        axarr[2].set_title("Loss head A")

        axarr[3].clear()
        axarr[3].plot(config.epoch_loss_no_lamb_head_A)
        axarr[3].set_title("Loss no lamb head A")

        axarr[4].clear()
        axarr[4].plot(config.epoch_loss_head_B)
        axarr[4].set_title("Loss head B")

        axarr[5].clear()
        axarr[5].plot(config.epoch_loss_no_lamb_head_B)
        axarr[5].set_title("Loss no lamb head B")

        fig.canvas.draw_idle()
        fig.savefig(os.path.join(config.out_dir, "plots.png"))

        if is_best or (e_i % config.save_freq == 0):
            net.module.cpu()
            save_dict = {
                "net": net.module.state_dict(),
                "optimiser": optimiser.state_dict()
            }

            if e_i % config.save_freq == 0:
                torch.save(save_dict,
                           os.path.join(config.out_dir, "latest.pytorch"))
                config.last_epoch = e_i  # for last saved version

            if is_best:
                torch.save(save_dict,
                           os.path.join(config.out_dir, "best.pytorch"))

                with open(os.path.join(config.out_dir, "best_config.pickle"),
                          'wb') as outfile:
                    pickle.dump(config, outfile)

                with open(os.path.join(config.out_dir, "best_config.txt"),
                          "w") as text_file:
                    text_file.write("%s" % config)

            net.module.cuda()

        with open(os.path.join(config.out_dir, "config.pickle"),
                  'wb') as outfile:
            pickle.dump(config, outfile)

        with open(os.path.join(config.out_dir, "config.txt"),
                  "w") as text_file:
            text_file.write("%s" % config)

        if config.test_code:
            exit(0)
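
The checkpoints written by this example are plain dicts with "net" and "optimiser" entries, and the matching config object is pickled next to them. A hedged sketch (not part of the project) of reloading the latest checkpoint outside the training script; it assumes archs and get_opt are importable exactly as in the example above, and out_dir is the experiment output directory:

import os
import pickle

import torch

out_dir = "/path/to/out_dir"  # hypothetical experiment output directory

# Rebuild the network from the pickled config, then restore the latest weights.
with open(os.path.join(out_dir, "config.pickle"), "rb") as f:
    config = pickle.load(f)

net = archs.__dict__[config.arch](config)
checkpoint = torch.load(os.path.join(out_dir, "latest.pytorch"),
                        map_location="cpu")
net.load_state_dict(checkpoint["net"])  # state_dict of net.module in the example

optimiser = get_opt(config.opt)(net.parameters(), lr=config.lr)
optimiser.load_state_dict(checkpoint["optimiser"])
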
Example #3
)
net.load_state_dict(net_state)
net.cuda()
net = torch.nn.DataParallel(net)

stats_dict = segmentation_eval(
    old_config,
    net,
    mapping_assignment_dataloader=mapping_assignment_dataloader,
    mapping_test_dataloader=mapping_test_dataloader,
    sobel=(not old_config.no_sobel),
    using_IR=old_config.using_IR,
    return_only=True,
)
assert isinstance(stats_dict, dict)

acc = stats_dict["best"]

config.epoch_stats = [stats_dict]
config.epoch_acc = [acc]
config.epoch_avg_subhead_acc = stats_dict["avg"]

print("Time %s: \n %s" % (datetime.now(), nice(config.epoch_stats[-1])))
sys.stdout.flush()

with open(os.path.join(config.out_dir, "config.pickle"), "wb") as outfile:
    pickle.dump(config, outfile)

with open(os.path.join(config.out_dir, "config.txt"), "w") as text_file:
    text_file.write("%s" % config)
Example #4
    config.epoch_loss_no_lamb = config.epoch_loss_no_lamb[:(next_epoch - 1)]
else:
    config.epoch_acc = []
    config.epoch_avg_subhead_acc = []
    config.epoch_stats = []

    config.epoch_loss = []
    config.epoch_loss_no_lamb = []

    _ = cluster_eval(config, net,
                     mapping_assignment_dataloader=mapping_assignment_dataloader,
                     mapping_test_dataloader=mapping_test_dataloader,
                     sobel=True)

    print("Pre: time %s: \n %s" %
          (datetime.now(), nice(config.epoch_stats[-1])))
    sys.stdout.flush()
    next_epoch = 1

fig, axarr = plt.subplots(4, sharex=False, figsize=(20, 20))

# Train ------------------------------------------------------------------------

for e_i in range(next_epoch, config.num_epochs):
    print("Starting e_i: %d" % e_i)
    sys.stdout.flush()

    iterators = (d for d in dataloaders)

    b_i = 0
    if e_i in config.lr_schedule:
Example #5
def train(render_count=-1):
    dataloaders_head_A, dataloaders_head_B, \
    mapping_assignment_dataloader, mapping_test_dataloader = \
        cluster_twohead_create_dataloaders(config)

    net = archs.__dict__[config.arch](config)
    if config.restart:
        model_path = os.path.join(config.out_dir, net_name)
        print("Model path: %s" % model_path)
        net.load_state_dict(
            torch.load(model_path, map_location=lambda storage, loc: storage))

    net.cuda()
    net = torch.nn.DataParallel(net)
    net.train()

    optimiser = get_opt(config.opt)(net.module.parameters(), lr=config.lr)
    if config.restart:
        if not (given_config is not None and given_config.num_epochs == 0):
            print("loading latest opt")
            optimiser.load_state_dict(
                torch.load(os.path.join(config.out_dir, opt_name)))

    heads = ["B", "A"]
    if config.head_A_first:
        heads = ["A", "B"]

    head_epochs = {}
    head_epochs["A"] = config.head_A_epochs
    head_epochs["B"] = config.head_B_epochs

    # Results
    # ----------------------------------------------------------------------

    if config.restart:
        if not config.restart_from_best:
            next_epoch = config.last_epoch + 1  # corresponds to last saved model
        else:
            next_epoch = np.argmax(np.array(config.epoch_acc)) + 1
        print("starting from epoch %d" % next_epoch)

        # in case we overshot without saving
        config.epoch_acc = config.epoch_acc[:next_epoch]  # in case we overshot
        config.epoch_avg_subhead_acc = config.epoch_avg_subhead_acc[:
                                                                    next_epoch]
        config.epoch_stats = config.epoch_stats[:next_epoch]

        if config.double_eval:
            config.double_eval_acc = config.double_eval_acc[:next_epoch]
            config.double_eval_avg_subhead_acc = config.double_eval_avg_subhead_acc[:
                                                                                    next_epoch]
            config.double_eval_stats = config.double_eval_stats[:next_epoch]

        config.epoch_loss_head_A = config.epoch_loss_head_A[:(next_epoch - 1)]
        config.epoch_loss_no_lamb_head_A = config.epoch_loss_no_lamb_head_A[:(
            next_epoch - 1)]

        config.epoch_loss_head_B = config.epoch_loss_head_B[:(next_epoch - 1)]
        config.epoch_loss_no_lamb_head_B = config.epoch_loss_no_lamb_head_B[:(
            next_epoch - 1)]
    else:
        config.epoch_acc = []
        config.epoch_avg_subhead_acc = []
        config.epoch_stats = []

        if config.double_eval:
            config.double_eval_acc = []
            config.double_eval_avg_subhead_acc = []
            config.double_eval_stats = []

        config.epoch_loss_head_A = []
        config.epoch_loss_no_lamb_head_A = []

        config.epoch_loss_head_B = []
        config.epoch_loss_no_lamb_head_B = []

        subhead = None
        if config.select_subhead_on_loss:
            subhead = get_subhead_using_loss(config,
                                             dataloaders_head_B,
                                             net,
                                             sobel=False,
                                             lamb=config.lamb_B)
        _ = cluster_eval(
            config,
            net,
            mapping_assignment_dataloader=mapping_assignment_dataloader,
            mapping_test_dataloader=mapping_test_dataloader,
            sobel=False,
            use_subhead=subhead)

        print("Pre: time %s: \n %s" %
              (datetime.now(), nice(config.epoch_stats[-1])))
        if config.double_eval:
            print("double eval: \n %s" % (nice(config.double_eval_stats[-1])))
        sys.stdout.flush()
        next_epoch = 1

    fig, axarr = plt.subplots(6 + 2 * int(config.double_eval),
                              sharex=False,
                              figsize=(20, 20))

    save_progression = hasattr(config, "save_progression") and \
                       config.save_progression
    if save_progression:
        save_progression_count = 0
        save_progress(config,
                      net,
                      mapping_assignment_dataloader,
                      mapping_test_dataloader,
                      save_progression_count,
                      sobel=False,
                      render_count=render_count)
        save_progression_count += 1

    # Train
    # ------------------------------------------------------------------------

    for e_i in xrange(next_epoch, config.num_epochs):
        print("Starting e_i: %d" % e_i)

        if e_i in config.lr_schedule:
            optimiser = update_lr(optimiser, lr_mult=config.lr_mult)

        for head_i in range(2):
            head = heads[head_i]
            if head == "A":
                dataloaders = dataloaders_head_A
                epoch_loss = config.epoch_loss_head_A
                epoch_loss_no_lamb = config.epoch_loss_no_lamb_head_A
                lamb = config.lamb_A
            elif head == "B":
                dataloaders = dataloaders_head_B
                epoch_loss = config.epoch_loss_head_B
                epoch_loss_no_lamb = config.epoch_loss_no_lamb_head_B
                lamb = config.lamb_B

            avg_loss = 0.  # over heads and head_epochs (and subheads)
            avg_loss_no_lamb = 0.
            avg_loss_count = 0

            for head_i_epoch in range(head_epochs[head]):
                sys.stdout.flush()

                iterators = (d for d in dataloaders)

                b_i = 0
                for tup in itertools.izip(*iterators):
                    net.module.zero_grad()

                    all_imgs = torch.zeros(
                        (config.batch_sz, config.in_channels, config.input_sz,
                         config.input_sz)).cuda()
                    all_imgs_tf = torch.zeros(
                        (config.batch_sz, config.in_channels, config.input_sz,
                         config.input_sz)).cuda()

                    imgs_curr = tup[0][0]  # always the first
                    curr_batch_sz = imgs_curr.size(0)
                    for d_i in xrange(config.num_dataloaders):
                        imgs_tf_curr = tup[1 + d_i][0]  # from 2nd to last
                        assert (curr_batch_sz == imgs_tf_curr.size(0))

                        actual_batch_start = d_i * curr_batch_sz
                        actual_batch_end = actual_batch_start + curr_batch_sz
                        all_imgs[actual_batch_start:actual_batch_end, :, :, :] = \
                            imgs_curr.cuda()
                        all_imgs_tf[actual_batch_start:actual_batch_end, :, :, :] = \
                            imgs_tf_curr.cuda()

                    if not (curr_batch_sz == config.dataloader_batch_sz):
                        print("last batch sz %d" % curr_batch_sz)

                    curr_total_batch_sz = curr_batch_sz * config.num_dataloaders  # times 2
                    all_imgs = all_imgs[:curr_total_batch_sz, :, :, :]
                    all_imgs_tf = all_imgs_tf[:curr_total_batch_sz, :, :, :]

                    x_outs = net(all_imgs)
                    x_tf_outs = net(all_imgs_tf)

                    avg_loss_batch = None  # avg over the heads
                    avg_loss_no_lamb_batch = None
                    for i in xrange(config.num_subheads):
                        loss, loss_no_lamb = IID_loss(x_outs[i],
                                                      x_tf_outs[i],
                                                      lamb=lamb)
                        if avg_loss_batch is None:
                            avg_loss_batch = loss
                            avg_loss_no_lamb_batch = loss_no_lamb
                        else:
                            avg_loss_batch += loss
                            avg_loss_no_lamb_batch += loss_no_lamb

                    avg_loss_batch /= config.num_subheads
                    avg_loss_no_lamb_batch /= config.num_subheads

                    if ((b_i % 100) == 0) or (e_i == next_epoch):
                        print(
                            "Model ind %d epoch %d head %s batch: %d avg loss %f avg loss no lamb %f time %s" % \
                            (config.model_ind, e_i, head, b_i, avg_loss_batch.item(), avg_loss_no_lamb_batch.item(),
                             datetime.now()))
                        sys.stdout.flush()

                    if not np.isfinite(avg_loss_batch.item()):
                        print("Loss is not finite... %s:" %
                              avg_loss_batch.item())
                        exit(1)

                    avg_loss += avg_loss_batch.item()
                    avg_loss_no_lamb += avg_loss_no_lamb_batch.item()
                    avg_loss_count += 1

                    avg_loss_batch.backward()
                    optimiser.step()

                    if ((b_i % 50) == 0) and save_progression:
                        save_progress(config,
                                      net,
                                      mapping_assignment_dataloader,
                                      mapping_test_dataloader,
                                      save_progression_count,
                                      sobel=False,
                                      render_count=render_count)
                        save_progression_count += 1

                    b_i += 1
                    if b_i == 2 and config.test_code:
                        break

            avg_loss = float(avg_loss / avg_loss_count)
            avg_loss_no_lamb = float(avg_loss_no_lamb / avg_loss_count)

            epoch_loss.append(avg_loss)
            epoch_loss_no_lamb.append(avg_loss_no_lamb)

        # Eval
        # -----------------------------------------------------------------------

        subhead = None
        if config.select_subhead_on_loss:
            subhead = get_subhead_using_loss(config,
                                             dataloaders_head_B,
                                             net,
                                             sobel=False,
                                             lamb=config.lamb_B)
        is_best = cluster_eval(
            config,
            net,
            mapping_assignment_dataloader=mapping_assignment_dataloader,
            mapping_test_dataloader=mapping_test_dataloader,
            sobel=False,
            use_subhead=subhead)

        print("Pre: time %s: \n %s" %
              (datetime.now(), nice(config.epoch_stats[-1])))
        if config.double_eval:
            print("double eval: \n %s" % (nice(config.double_eval_stats[-1])))
        sys.stdout.flush()

        axarr[0].clear()
        axarr[0].plot(config.epoch_acc)
        axarr[0].set_title("acc (best), top: %f" % max(config.epoch_acc))

        axarr[1].clear()
        axarr[1].plot(config.epoch_avg_subhead_acc)
        axarr[1].set_title("acc (avg), top: %f" %
                           max(config.epoch_avg_subhead_acc))

        axarr[2].clear()
        axarr[2].plot(config.epoch_loss_head_A)
        axarr[2].set_title("Loss head A")

        axarr[3].clear()
        axarr[3].plot(config.epoch_loss_no_lamb_head_A)
        axarr[3].set_title("Loss no lamb head A")

        axarr[4].clear()
        axarr[4].plot(config.epoch_loss_head_B)
        axarr[4].set_title("Loss head B")

        axarr[5].clear()
        axarr[5].plot(config.epoch_loss_no_lamb_head_B)
        axarr[5].set_title("Loss no lamb head B")

        if config.double_eval:
            axarr[6].clear()
            axarr[6].plot(config.double_eval_acc)
            axarr[6].set_title("double eval acc (best), top: %f" %
                               max(config.double_eval_acc))

            axarr[7].clear()
            axarr[7].plot(config.double_eval_avg_subhead_acc)
            axarr[7].set_title("double eval acc (avg), top: %f" %
                               max(config.double_eval_avg_subhead_acc))

        fig.tight_layout()
        fig.canvas.draw_idle()
        fig.savefig(os.path.join(config.out_dir, "plots.png"))

        if is_best or (e_i % config.save_freq == 0):
            net.module.cpu()

            if e_i % config.save_freq == 0:
                torch.save(net.module.state_dict(),
                           os.path.join(config.out_dir, "latest_net.pytorch"))
                torch.save(
                    optimiser.state_dict(),
                    os.path.join(config.out_dir, "latest_optimiser.pytorch"))

                config.last_epoch = e_i  # for last saved version

            if is_best:
                # also serves as backup if hardware fails - less likely to hit this
                torch.save(net.module.state_dict(),
                           os.path.join(config.out_dir, "best_net.pytorch"))
                torch.save(
                    optimiser.state_dict(),
                    os.path.join(config.out_dir, "best_optimiser.pytorch"))

                with open(os.path.join(config.out_dir, "best_config.pickle"),
                          'wb') as outfile:
                    pickle.dump(config, outfile)

                with open(os.path.join(config.out_dir, "best_config.txt"),
                          "w") as text_file:
                    text_file.write("%s" % config)

            net.module.cuda()

        with open(os.path.join(config.out_dir, "config.pickle"),
                  'wb') as outfile:
            pickle.dump(config, outfile)

        with open(os.path.join(config.out_dir, "config.txt"),
                  "w") as text_file:
            text_file.write("%s" % config)

        if config.test_code:
            exit(0)
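
The inner loop above averages IID_loss over subheads; the loss itself is not shown in these examples. As a rough reference only (a sketch of the usual IIC mutual-information objective, not the project's implementation), something along these lines takes the paired (n, k) softmax outputs and maximises the mutual information between them:

import torch

def iid_loss_sketch(z, z_tf, lamb=1.0, eps=1e-8):
    # z, z_tf: (n, k) soft cluster assignments for an image and its transform.
    k = z.shape[1]
    p = z.t().mm(z_tf)                 # (k, k) unnormalised joint over cluster pairs
    p = ((p + p.t()) / 2.0) / p.sum()  # symmetrise and normalise
    p = p.clamp(min=eps)
    pi = p.sum(dim=1).view(k, 1).expand(k, k)  # row marginal
    pj = p.sum(dim=0).view(1, k).expand(k, k)  # column marginal
    loss = -(p * (torch.log(p) - lamb * torch.log(pi)
                  - lamb * torch.log(pj))).sum()
    loss_no_lamb = -(p * (torch.log(p) - torch.log(pi) - torch.log(pj))).sum()
    return loss, loss_no_lamb
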
Example #6
def evaluation(config, net, optimiser, mapping_assignment_dataloader, mapping_test_dataloader, fig, axarr, current_epoch):
    """Evaluates and logs results from the net and model checkpointing

    Params:
      config: TODO
      net: TODO
      optimiser: TODO
      mapping_assignment_dataloader: TODO
      mapping_test_dataloader: TODO
      fig: TODO
      axarr: TODO
      current_epoch: TODO
    """

    is_best = segmentation_eval(config, net,
                                mapping_assignment_dataloader=mapping_assignment_dataloader,
                                mapping_test_dataloader=mapping_test_dataloader,
                                sobel=(
                                    not config.no_sobel),
                                using_IR=config.using_IR)

    print(
        "Pre: time %s: \n %s" % (datetime.now(), nice(config.epoch_stats[-1])))
    sys.stdout.flush()

    axarr[0].clear()
    axarr[0].plot(config.epoch_acc)
    axarr[0].set_title("acc (best), top: %f" % max(config.epoch_acc))

    axarr[1].clear()
    axarr[1].plot(config.epoch_avg_subhead_acc)
    axarr[1].set_title("acc (avg), top: %f" %
                       max(config.epoch_avg_subhead_acc))

    axarr[2].clear()
    axarr[2].plot(config.epoch_loss_head_A)
    axarr[2].set_title("Loss head A")

    axarr[3].clear()
    axarr[3].plot(config.epoch_loss_no_lamb_head_A)
    axarr[3].set_title("Loss no lamb head A")

    axarr[4].clear()
    axarr[4].plot(config.epoch_loss_head_B)
    axarr[4].set_title("Loss head B")

    axarr[5].clear()
    axarr[5].plot(config.epoch_loss_no_lamb_head_B)
    axarr[5].set_title("Loss no lamb head B")

    fig.canvas.draw_idle()
    fig.savefig(os.path.join(config.out_dir, "plots.png"))

    if is_best or (current_epoch % config.save_freq == 0):
        net.module.cpu()
        save_dict = {"net": net.module.state_dict(),
                     "optimiser": optimiser.state_dict()}

        if current_epoch % config.save_freq == 0:
            torch.save(save_dict, os.path.join(
                config.out_dir, "latest.pytorch"))
            config.last_epoch = current_epoch  # for last saved version

        if is_best:
            torch.save(save_dict, os.path.join(config.out_dir, "best.pytorch"))

            with open(os.path.join(config.out_dir, "best_config.pickle"),
                      'wb') as outfile:
                pickle.dump(config, outfile)

            with open(os.path.join(config.out_dir, "best_config.txt"),
                      "w") as text_file:
                text_file.write("%s" % config)

        if not config.nocuda:
            net.module.cuda()

    with open(os.path.join(config.out_dir, "config.pickle"), 'wb') as outfile:
        pickle.dump(config, outfile)

    with open(os.path.join(config.out_dir, "config.txt"), "w") as text_file:
        text_file.write("%s" % config)

    if config.test_code:
        exit(0)
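
A hypothetical driver loop around the evaluation() helper above (train_one_epoch is a placeholder for the per-epoch training code from the other examples; everything else is set up exactly as in Example #2):

fig, axarr = plt.subplots(6, sharex=False, figsize=(20, 20))
for e_i in range(next_epoch, config.num_epochs):
    train_one_epoch(config, net, optimiser, e_i)  # placeholder
    evaluation(config, net, optimiser,
               mapping_assignment_dataloader, mapping_test_dataloader,
               fig, axarr, current_epoch=e_i)
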
Example #7
    if config.select_subhead_on_loss:
        subhead = get_subhead_using_loss(config,
                                         dataloaders_head_B,
                                         net,
                                         sobel=True,
                                         lamb=config.lamb)
    _ = cluster_eval(
        config,
        net,
        mapping_assignment_dataloader=mapping_assignment_dataloader,
        mapping_test_dataloader=mapping_test_dataloader,
        sobel=True,
        use_subhead=subhead)

    print("Pre: time %s: \n %s" %
          (datetime.now(), nice(config.epoch_stats[-1])))
    if config.double_eval:
        print("double eval: \n %s" % (nice(config.double_eval_stats[-1])))
    sys.stdout.flush()
    next_epoch = 1

fig, axarr = plt.subplots(6 + 2 * int(config.double_eval),
                          sharex=False,
                          figsize=(20, 20))

# Train ------------------------------------------------------------------------

for e_i in xrange(next_epoch, config.num_epochs):
    print("Starting e_i: %d" % (e_i))

    if e_i in config.lr_schedule:
Example #8
File: cluster.py  Project: hendraet/IIC
def train(config, net, optimiser, render_count=-1):
    # TODO: center crop ok or does it remove too much information if text is aligned left
    dataloader_list, mapping_assignment_dataloader, mapping_test_dataloader = get_dataloader_list(
        config)

    num_heads = len(config.output_ks)

    # Results ----------------------------------------------------------------------

    if config.restart:
        if not config.restart_from_best:
            next_epoch = config.last_epoch + 1  # corresponds to last saved model
        else:
            next_epoch = np.argmax(np.array(config.epoch_acc)) + 1
        print("starting from epoch %d" % next_epoch)

        # in case we overshot without saving
        config.epoch_acc = config.epoch_acc[:next_epoch]  # in case we overshot
        config.epoch_avg_subhead_acc = config.epoch_avg_subhead_acc[:
                                                                    next_epoch]
        config.epoch_stats = config.epoch_stats[:next_epoch]

        if config.double_eval:
            config.double_eval_acc = config.double_eval_acc[:next_epoch]
            config.double_eval_avg_subhead_acc = config.double_eval_avg_subhead_acc[:
                                                                                    next_epoch]
            config.double_eval_stats = config.double_eval_stats[:next_epoch]

        for i, loss in enumerate(config.epoch_loss):
            config.epoch_loss[i] = loss[:(next_epoch - 1)]
        for i, loss_no_lamb in enumerate(config.epoch_loss_no_lamb):
            config.epoch_loss_no_lamb[i] = loss_no_lamb[:(next_epoch - 1)]

    else:
        config.epoch_acc = []
        config.epoch_avg_subhead_acc = []
        config.epoch_stats = []

        if config.double_eval:
            config.double_eval_acc = []
            config.double_eval_avg_subhead_acc = []
            config.double_eval_stats = []

        config.epoch_loss = [[] for _ in range(num_heads)]
        config.epoch_loss_no_lamb = [[] for _ in range(num_heads)]

        subhead = None
        if config.select_subhead_on_loss:
            assert num_heads == 2
            subhead = get_subhead_using_loss(config,
                                             dataloader_list[1],
                                             net,
                                             sobel=config.sobel,
                                             lamb=config.lamb)

        _ = cluster_eval(
            config,
            net,
            mapping_assignment_dataloader=mapping_assignment_dataloader,
            mapping_test_dataloader=mapping_test_dataloader,
            sobel=config.sobel,
            use_subhead=subhead)

        print("Pre: time %s: \n %s" %
              (datetime.now(), nice(config.epoch_stats[-1])))
        if config.double_eval:
            print("double eval: \n %s" % (nice(config.double_eval_stats[-1])))
        sys.stdout.flush()
        next_epoch = 1

    fig, axarr = plt.subplots(2 + 2 * num_heads + 2 * int(config.double_eval),
                              sharex=False,
                              figsize=(20, 20))

    save_progression = hasattr(config,
                               "save_progression") and config.save_progression
    if save_progression:
        save_progression_count = 0
        save_progress(config,
                      net,
                      mapping_assignment_dataloader,
                      mapping_test_dataloader,
                      save_progression_count,
                      sobel=config.sobel,
                      render_count=render_count)
        save_progression_count += 1

    # Train ------------------------------------------------------------------------

    heads = range(num_heads)
    if config.reverse_heads:
        heads = reversed(heads)

    for e_i in xrange(next_epoch, config.num_epochs + 1):
        print("Starting e_i: %d" % e_i)

        if e_i in config.lr_schedule:
            optimiser = update_lr(optimiser, lr_mult=config.lr_mult)

        for head_idx in heads:
            dataloaders = dataloader_list[head_idx]
            epoch_loss = config.epoch_loss[head_idx]
            epoch_loss_no_lamb = config.epoch_loss_no_lamb[head_idx]

            avg_loss = 0.  # over heads and head_epochs (and subheads)
            avg_loss_no_lamb = 0.
            avg_loss_count = 0

            for head_i_epoch in range(config.head_epochs[head_idx]):
                sys.stdout.flush()

                iterators = (d for d in dataloaders)

                b_i = 0
                for tup in itertools.izip(*iterators):
                    net.module.zero_grad()

                    in_channels = config.in_channels
                    if config.sobel:
                        # one less because this is before sobel
                        in_channels -= 1
                    all_imgs = torch.zeros(
                        (config.batch_sz, in_channels, config.input_sz[0],
                         config.input_sz[1])).cuda()
                    all_imgs_tf = torch.zeros(
                        (config.batch_sz, in_channels, config.input_sz[0],
                         config.input_sz[1])).cuda()

                    imgs_curr = tup[0][0]  # always the first
                    curr_batch_sz = imgs_curr.size(0)
                    for d_i in xrange(config.num_dataloaders):
                        imgs_tf_curr = tup[1 + d_i][0]  # from 2nd to last
                        assert (curr_batch_sz == imgs_tf_curr.size(0))

                        actual_batch_start = d_i * curr_batch_sz
                        actual_batch_end = actual_batch_start + curr_batch_sz
                        all_imgs[actual_batch_start:
                                 actual_batch_end, :, :, :] = imgs_curr.cuda()
                        all_imgs_tf[
                            actual_batch_start:
                            actual_batch_end, :, :, :] = imgs_tf_curr.cuda()

                    if not (curr_batch_sz == config.dataloader_batch_sz):
                        print("last batch sz %d" % curr_batch_sz)

                    curr_total_batch_sz = curr_batch_sz * config.num_dataloaders
                    all_imgs = all_imgs[:curr_total_batch_sz, :, :, :]
                    all_imgs_tf = all_imgs_tf[:curr_total_batch_sz, :, :, :]

                    if config.sobel:
                        all_imgs = sobel_process(all_imgs, config.include_rgb)
                        all_imgs_tf = sobel_process(all_imgs_tf,
                                                    config.include_rgb)

                    x_outs = net(all_imgs, head_idx=head_idx)
                    x_tf_outs = net(all_imgs_tf, head_idx=head_idx)

                    avg_loss_batch = None  # avg over the subheads
                    avg_loss_no_lamb_batch = None
                    for i in xrange(config.num_subheads):
                        loss, loss_no_lamb = IID_loss(x_outs[i],
                                                      x_tf_outs[i],
                                                      lamb=config.lamb)
                        if avg_loss_batch is None:
                            avg_loss_batch = loss
                            avg_loss_no_lamb_batch = loss_no_lamb
                        else:
                            avg_loss_batch += loss
                            avg_loss_no_lamb_batch += loss_no_lamb

                    avg_loss_batch /= config.num_subheads
                    avg_loss_no_lamb_batch /= config.num_subheads

                    if ((b_i % 100) == 0) or (e_i == next_epoch and b_i < 10):
                        print(
                            "Model ind %d epoch %d head %s head_i_epoch %d batch %d: avg loss %f avg loss no lamb %f "
                            "time %s" %
                            (config.model_ind, e_i, str(head_idx),
                             head_i_epoch, b_i, avg_loss_batch.item(),
                             avg_loss_no_lamb_batch.item(), datetime.now()))
                        sys.stdout.flush()

                    if not np.isfinite(avg_loss_batch.item()):
                        print("Loss is not finite... %s:" %
                              str(avg_loss_batch))
                        sys.exit(1)

                    avg_loss += avg_loss_batch.item()
                    avg_loss_no_lamb += avg_loss_no_lamb_batch.item()
                    avg_loss_count += 1

                    avg_loss_batch.backward()
                    optimiser.step()

                    if ((b_i % 50) == 0) and save_progression:
                        save_progress(config,
                                      net,
                                      mapping_assignment_dataloader,
                                      mapping_test_dataloader,
                                      save_progression_count,
                                      sobel=False,
                                      render_count=render_count)
                        save_progression_count += 1

                    b_i += 1
                    if b_i == 2 and config.test_code:
                        break

            avg_loss = float(avg_loss / avg_loss_count)
            avg_loss_no_lamb = float(avg_loss_no_lamb / avg_loss_count)

            epoch_loss.append(avg_loss)
            epoch_loss_no_lamb.append(avg_loss_no_lamb)

        # Eval -----------------------------------------------------------------------

        # Can also pick the subhead using the evaluation process (to do this, set use_subhead=None)
        subhead = None
        if config.select_subhead_on_loss:
            assert num_heads == 2
            subhead = get_subhead_using_loss(config,
                                             dataloader_list[1],
                                             net,
                                             sobel=config.sobel,
                                             lamb=config.lamb)

        is_best = cluster_eval(
            config,
            net,
            mapping_assignment_dataloader=mapping_assignment_dataloader,
            mapping_test_dataloader=mapping_test_dataloader,
            sobel=config.sobel,
            use_subhead=subhead)

        print("Pre: time %s: \n %s" %
              (datetime.now(), nice(config.epoch_stats[-1])))
        if config.double_eval:
            print("double eval: \n %s" % (nice(config.double_eval_stats[-1])))
        sys.stdout.flush()

        axarr[0].clear()
        axarr[0].plot(config.epoch_acc)
        axarr[0].set_title("acc (best), top: %f" % max(config.epoch_acc))

        axarr[1].clear()
        axarr[1].plot(config.epoch_avg_subhead_acc)
        axarr[1].set_title("acc (avg), top: %f" %
                           max(config.epoch_avg_subhead_acc))

        last_ax_idx = 1
        starting_ax_idx = last_ax_idx + 1
        for i in range(num_heads):
            axarr[starting_ax_idx + i].clear()
            axarr[starting_ax_idx + i].plot(config.epoch_loss[i])
            axarr[starting_ax_idx + i].set_title("Loss head_idx " + str(i))

            axarr[starting_ax_idx + i + 1].clear()
            axarr[starting_ax_idx + i + 1].plot(config.epoch_loss_no_lamb[i])
            axarr[starting_ax_idx + i + 1].set_title("Loss no lamb head_idx " +
                                                     str(i))

        if config.double_eval:
            next_index = starting_ax_idx + 2 * num_heads
            axarr[next_index].clear()
            axarr[next_index].plot(config.double_eval_acc)
            axarr[next_index].set_title("double eval acc (best), top: %f" %
                                        max(config.double_eval_acc))

            axarr[next_index + 1].clear()
            axarr[next_index + 1].plot(config.double_eval_avg_subhead_acc)
            axarr[next_index +
                  1].set_title("double eval acc (avg), top: %f" %
                               max(config.double_eval_avg_subhead_acc))

        fig.tight_layout()
        fig.canvas.draw_idle()
        fig.savefig(os.path.join(config.out_dir, "plots.png"))

        if is_best or (e_i % config.save_freq == 0):
            net.module.cpu()

            if e_i % config.save_freq == 0:
                torch.save(net.module.state_dict(),
                           os.path.join(config.out_dir, "latest_net.pytorch"))
                torch.save(
                    optimiser.state_dict(),
                    os.path.join(config.out_dir, "latest_optimiser.pytorch"))
                config.last_epoch = e_i  # for last saved version

            if is_best:
                # also serves as backup if hardware fails - less likely to hit this
                torch.save(net.module.state_dict(),
                           os.path.join(config.out_dir, "best_net.pytorch"))
                torch.save(
                    optimiser.state_dict(),
                    os.path.join(config.out_dir, "best_optimiser.pytorch"))

                with open(os.path.join(config.out_dir, "best_config.pickle"),
                          'wb') as outfile:
                    pickle.dump(config, outfile)

                with open(os.path.join(config.out_dir, "best_config.txt"),
                          "w") as text_file:
                    text_file.write("%s" % config)

            net.module.cuda()

        with open(os.path.join(config.out_dir, "config.pickle"),
                  'wb') as outfile:
            pickle.dump(config, outfile)

        with open(os.path.join(config.out_dir, "config.txt"),
                  "w") as text_file:
            text_file.write("%s" % config)

        if config.test_code:
            exit(0)
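
update_lr is called from the lr_schedule check in several examples above but is not shown. A plausible minimal sketch (an assumption, not the project's code) that simply scales every parameter group's learning rate by lr_mult:

def update_lr(optimiser, lr_mult=0.1):
    # Shrink the learning rate of every parameter group in place.
    for param_group in optimiser.param_groups:
        param_group["lr"] *= lr_mult
    return optimiser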