def train(config):

    # Fix random seed and enable deterministic calculations
    torch.manual_seed(config["seed"])
    torch.cuda.manual_seed(config["seed"])
    np.random.seed(config["seed"])
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True

    ## In case of highly imbalanced classes one can add weights to the cross-entropy loss,
    ## computed as: [1 - (x / sum(nSamples)) for x in nSamples]
    #weights = [0.936,0.064]
    #class_weights = torch.FloatTensor(weights).cuda()
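    ## For example (hypothetical counts): nSamples = [640, 9360] would give
    ## weights = [1 - 640/10000, 1 - 9360/10000] = [0.936, 0.064],
    ## i.e. the rarer class is up-weighted before being passed as CrossEntropyLoss(weight=class_weights).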

    # Set up summary writer
    writer = SummaryWriter(config['output_path'])

    class_num = config["network"]["params"]["class_num"]
    class_criterion = nn.CrossEntropyLoss(
    )  # optionally add "weight=class_weights" in case of highly imbalanced classes
    transfer_criterion = config["loss"]["name"]
    center_criterion = config["loss"]["discriminant_loss"](
        num_classes=class_num,
        feat_dim=config["network"]["params"]["bottleneck_dim"])
    loss_params = config["loss"]
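    # The discriminant ("Fisher") loss acts on the bottleneck features: it is assumed to pull
    # features of a class towards their class center (intra-class term) while pushing different
    # class centers apart (inter-class term); its components are logged below as the
    # intra-group and inter-group Fisher losses.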

    # Prepare image data. Image shuffling is fixed with the random seed choice.
    # Train:validation:test = 70:10:20
    dsets = {}
    dset_loaders = {}

    pristine_indices = torch.randperm(len(pristine_x))
    # Train sample
    pristine_x_train = pristine_x[
        pristine_indices[:int(np.floor(.7 * len(pristine_x)))]]
    pristine_y_train = pristine_y[
        pristine_indices[:int(np.floor(.7 * len(pristine_x)))]]
    # Validation sample --- gets passed into test functions in train file
    pristine_x_valid = pristine_x[pristine_indices[
        int(np.floor(.7 * len(pristine_x))):int(np.floor(.8 *
                                                         len(pristine_x)))]]
    pristine_y_valid = pristine_y[pristine_indices[
        int(np.floor(.7 * len(pristine_x))):int(np.floor(.8 *
                                                         len(pristine_x)))]]
    # Test sample for evaluation file
    pristine_x_test = pristine_x[
        pristine_indices[int(np.floor(.8 * len(pristine_x))):]]
    pristine_y_test = pristine_y[
        pristine_indices[int(np.floor(.8 * len(pristine_x))):]]

    noisy_indices = torch.randperm(len(noisy_x))
    # Train sample
    noisy_x_train = noisy_x[noisy_indices[:int(np.floor(.7 * len(noisy_x)))]]
    noisy_y_train = noisy_y[noisy_indices[:int(np.floor(.7 * len(noisy_x)))]]
    # Validation sample --- gets passed into test functions in train file
    noisy_x_valid = noisy_x[noisy_indices[int(np.floor(.7 * len(noisy_x))
                                              ):int(np.floor(.8 *
                                                             len(noisy_x)))]]
    noisy_y_valid = noisy_y[noisy_indices[int(np.floor(.7 * len(noisy_x))
                                              ):int(np.floor(.8 *
                                                             len(noisy_x)))]]
    # Test sample for evaluation file
    noisy_x_test = noisy_x[noisy_indices[int(np.floor(.8 * len(noisy_x))):]]
    noisy_y_test = noisy_y[noisy_indices[int(np.floor(.8 * len(noisy_x))):]]

    dsets["source"] = TensorDataset(pristine_x_train, pristine_y_train)
    dsets["target"] = TensorDataset(noisy_x_train, noisy_y_train)

    dsets["source_valid"] = TensorDataset(pristine_x_valid, pristine_y_valid)
    dsets["target_valid"] = TensorDataset(noisy_x_valid, noisy_y_valid)

    dsets["source_test"] = TensorDataset(pristine_x_test, pristine_y_test)
    dsets["target_test"] = TensorDataset(noisy_x_test, noisy_y_test)

    dset_loaders["source"] = DataLoader(dsets["source"],
                                        batch_size=128,
                                        shuffle=True,
                                        num_workers=1)
    dset_loaders["target"] = DataLoader(dsets["target"],
                                        batch_size=128,
                                        shuffle=True,
                                        num_workers=1)

    dset_loaders["source_valid"] = DataLoader(dsets["source_valid"],
                                              batch_size=64,
                                              shuffle=True,
                                              num_workers=1)
    dset_loaders["target_valid"] = DataLoader(dsets["target_valid"],
                                              batch_size=64,
                                              shuffle=True,
                                              num_workers=1)

    dset_loaders["source_test"] = DataLoader(dsets["source_test"],
                                             batch_size=64,
                                             shuffle=True,
                                             num_workers=1)
    dset_loaders["target_test"] = DataLoader(dsets["target_test"],
                                             batch_size=64,
                                             shuffle=True,
                                             num_workers=1)

    config['out_file'].write("dataset sizes: source={}, target={}\n".format(
        len(dsets["source"]), len(dsets["target"])))

    # Set number of epochs, and logging intervals
    config["num_iterations"] = len(
        dset_loaders["source"]) * config["epochs"] + 1
    config["test_interval"] = len(dset_loaders["source"])
    config["snapshot_interval"] = len(
        dset_loaders["source"]) * config["epochs"] * .25
    config["log_iter"] = len(dset_loaders["source"])

    # Print the configuration you are using
    config["out_file"].write("config: {}\n".format(config))
    config["out_file"].flush()

    # Set up early stop
    early_stop_engine = EarlyStopping(config["early_stop_patience"])
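    # EarlyStopping is assumed to track the validation classifier loss passed to
    # is_stop_training() below and to signal a stop once it has not improved for
    # "early_stop_patience" consecutive checks.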

    # Set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])

    # Optionally load a previously trained model checkpoint:
    if config["ckpt_path"] is not None:
        print('load model from {}'.format(config['ckpt_path']))
        ckpt = torch.load(config['ckpt_path'] + '/best_model.pth.tar')
        base_network.load_state_dict(ckpt['base_network'])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    # Collect parameters for the chosen network to be trained
    if "DeepMerge" in args.net:
        parameter_list = [{
            "params": base_network.parameters(),
            "lr_mult": 1,
            'decay_mult': 2
        }]
    elif net_config["params"]["new_cls"]:
        if net_config["params"]["use_bottleneck"]:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                            {"params":base_network.bottleneck.parameters(), "lr_mult":10, 'decay_mult':2}, \
                            {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
        else:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                            {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
    else:
        parameter_list = [{
            "params": base_network.parameters(),
            "lr_mult": 10,
            'decay_mult': 2
        }]

    # Class weights in case we need them; here the sample is balanced so all weights are 1.0
    class_weight = torch.from_numpy(np.array([1.0] * class_num))
    if use_gpu:
        class_weight = class_weight.cuda()

    parameter_list.append({
        "params": center_criterion.parameters(),
        "lr_mult": 10,
        'decay_mult': 1
    })
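    # The lr_mult / decay_mult entries are assumed to be consumed by the lr scheduler below,
    # scaling each parameter group's learning rate (and weight decay) relative to the base
    # optimizer settings, e.g. effective_lr = base_lr * lr_mult.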

    # Set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](parameter_list, \
                    **(optimizer_config["optim_params"]))

    # Set learning rate scheduler
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    scan_lr = []
    scan_loss = []

    ###################
    ###### TRAIN ######
    ###################
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    len_valid_source = len(dset_loaders["source_valid"])
    len_valid_target = len(dset_loaders["target_valid"])

    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0

    for i in range(config["num_iterations"]):

        if i % config["test_interval"] == 0:
            base_network.train(False)
            if config['loss']['ly_type'] == "cosine":
                temp_acc, _ = image_classification_test(dset_loaders, 'source_valid', \
                    base_network, gpu=use_gpu, verbose = False, save_where = None)
                train_acc, _ = image_classification_test(dset_loaders, 'source', \
                    base_network, gpu=use_gpu, verbose = False, save_where = None)
            elif config['loss']['ly_type'] == "euclidean":
                temp_acc, _ = distance_classification_test(dset_loaders, 'source_valid', \
                    base_network, center_criterion.centers.detach(), gpu=use_gpu, verbose = False, save_where = None)
                train_acc, _ = distance_classification_test(dset_loaders, 'source', \
                    base_network, center_criterion.centers.detach(), gpu=use_gpu, verbose = False, save_where = None)
            else:
                raise ValueError("no test method for cls loss: {}".format(
                    config['loss']['ly_type']))

            snapshot_obj = {
                'epoch': i / len(dset_loaders["source"]),
                "base_network": base_network.state_dict(),
                'valid accuracy': temp_acc,
                'train accuracy': train_acc,
            }

            snapshot_obj['center_criterion'] = center_criterion.state_dict()

            if (i + 1) % config["snapshot_interval"] == 0:
                torch.save(
                    snapshot_obj,
                    osp.join(
                        config["output_path"], "epoch_{}_model.pth.tar".format(
                            i / len(dset_loaders["source"]))))

            if temp_acc > best_acc:
                best_acc = temp_acc
                # Save best model
                torch.save(
                    snapshot_obj,
                    osp.join(config["output_path"], "best_model.pth.tar"))

            # Write to log file
            log_str = "epoch: {}, {} validation accuracy: {:.5f}, {} training accuracy: {:.5f}\n".format(
                i / len(dset_loaders["source"]), config['loss']['ly_type'],
                temp_acc, config['loss']['ly_type'], train_acc)
            config["out_file"].write(log_str)
            config["out_file"].flush()
            writer.add_scalar("validation accuracy", temp_acc,
                              i / len(dset_loaders["source"]))
            writer.add_scalar("training accuracy", train_acc,
                              i / len(dset_loaders["source"]))

        ## Train one iteration
        base_network.train(True)

        if i % config["log_iter"] == 0:
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)

        if config["optimizer"]["lr_type"] == "one-cycle":
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)

        if config["optimizer"]["lr_type"] == "linear":
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)

        optim = optimizer.state_dict()
        scan_lr.append(optim['param_groups'][0]['lr'])

        optimizer.zero_grad()

        try:
            inputs_source, labels_source = next(iter(dset_loaders["source"]))
            inputs_target, labels_target = next(iter(dset_loaders["target"]))

        except StopIteration:
            # Loader exhausted: restart it and fetch a fresh batch
            inputs_source, labels_source = next(iter(dset_loaders["source"]))
            inputs_target, labels_target = next(iter(dset_loaders["target"]))

        if use_gpu:
            inputs_source, inputs_target, labels_source = \
                Variable(inputs_source).cuda(), Variable(inputs_target).cuda(), \
                Variable(labels_source).cuda()
        else:
            inputs_source, inputs_target, labels_source = Variable(inputs_source), \
                Variable(inputs_target), Variable(labels_source)

        inputs = torch.cat((inputs_source, inputs_target), dim=0)
        source_batch_size = inputs_source.size(0)

        # Distance type. We use cosine.
        if config['loss']['ly_type'] == 'cosine':
            features, logits = base_network(inputs)
            source_logits = logits.narrow(0, 0, source_batch_size)
        elif config['loss']['ly_type'] == 'euclidean':
            features, _ = base_network(inputs)
            logits = -1.0 * loss.distance_to_centroids(
                features, center_criterion.centers.detach())
            source_logits = logits.narrow(0, 0, source_batch_size)
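        # (loss.distance_to_centroids is assumed to return the pairwise Euclidean distances
        #  between each feature vector and each class center, similar to
        #  torch.cdist(features, centers), so that the negated distance acts as a logit.)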

        # Transfer loss - MMD
        transfer_loss = transfer_criterion(features[:source_batch_size],
                                           features[source_batch_size:])
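        # (For intuition, a minimal linear-kernel MMD estimate between the two feature batches
        #  would be ((f_s.mean(0) - f_t.mean(0)) ** 2).sum(); the configured transfer_criterion
        #  may instead use a multi-kernel RBF estimator.)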

        # Source domain classification task loss
        classifier_loss = class_criterion(source_logits, labels_source.long())

        # Final loss in case we do not want to add Fisher loss and Entropy minimization
        if config["fisher_or_no"] == 'no':
            total_loss = loss_params["trade_off"] * transfer_loss \
            + classifier_loss

            scan_loss.append(total_loss.cpu().float().item())

            total_loss.backward()

            #################
            # Plot embeddings periodically. tSNE plots
            if args.blobs is not None and i / len(
                    dset_loaders["source"]) % 20 == 0:
                visualizePerformance(base_network,
                                     dset_loaders["source"],
                                     dset_loaders["target"],
                                     batch_size=128,
                                     num_of_samples=2000,
                                     imgName='embedding_' +
                                     str(i / len(dset_loaders["source"])),
                                     save_dir=osp.join(config["output_path"],
                                                       "blobs"))
            #################

            optimizer.step()

            if i % config["log_iter"] == 0:

                # In case we want to do a learning rate scan to find the best lr_cycle length:
                if config['lr_scan'] != 'no':
                    if not osp.exists(
                            osp.join(config["output_path"],
                                     "learning_rate_scan")):
                        os.makedirs(
                            osp.join(config["output_path"],
                                     "learning_rate_scan"))

                    plot_learning_rate_scan(
                        scan_lr, scan_loss, i / len(dset_loaders["source"]),
                        osp.join(config["output_path"], "learning_rate_scan"))

                # In case we want to visualize gradients:
                if config['grad_vis'] != 'no':
                    if not osp.exists(
                            osp.join(config["output_path"], "gradients")):
                        os.makedirs(
                            osp.join(config["output_path"], "gradients"))

                    plot_grad_flow(
                        osp.join(config["output_path"], "gradients"),
                        i / len(dset_loaders["source"]),
                        base_network.named_parameters())

                # Logging:
                config['out_file'].write(
                    'epoch {}: train total loss={:0.4f}, train transfer loss={:0.4f}, train classifier loss={:0.4f}\n'
                    .format(
                        i / len(dset_loaders["source"]),
                        total_loss.data.cpu().float().item(),
                        transfer_loss.data.cpu().float().item(),
                        classifier_loss.data.cpu().float().item(),
                    ))

                config['out_file'].flush()

                # Logging for tensorboard
                writer.add_scalar("training total loss",
                                  total_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training classifier loss",
                                  classifier_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training transfer loss",
                                  transfer_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))

                #################
                # Validation step
                #################
                for j in range(0, len(dset_loaders["source_valid"])):
                    base_network.train(False)
                    with torch.no_grad():

                        try:
                            inputs_valid_source, labels_valid_source = next(
                                iter(dset_loaders["source_valid"]))
                            inputs_valid_target, labels_valid_target = next(
                                iter(dset_loaders["target_valid"]))
                        except StopIteration:
                            # Loader exhausted: restart it and fetch a fresh batch
                            inputs_valid_source, labels_valid_source = next(
                                iter(dset_loaders["source_valid"]))
                            inputs_valid_target, labels_valid_target = next(
                                iter(dset_loaders["target_valid"]))

                        if use_gpu:
                            inputs_valid_source, inputs_valid_target, labels_valid_source = \
                                Variable(inputs_valid_source).cuda(), Variable(inputs_valid_target).cuda(), \
                                Variable(labels_valid_source).cuda()
                        else:
                            inputs_valid_source, inputs_valid_target, labels_valid_source = Variable(inputs_valid_source), \
                                Variable(inputs_valid_target), Variable(labels_valid_source)

                        valid_inputs = torch.cat(
                            (inputs_valid_source, inputs_valid_target), dim=0)
                        valid_source_batch_size = inputs_valid_source.size(0)

                        # Distance type. We use cosine.
                        if config['loss']['ly_type'] == 'cosine':
                            features, logits = base_network(valid_inputs)
                            source_logits = logits.narrow(
                                0, 0, valid_source_batch_size)
                        elif config['loss']['ly_type'] == 'euclidean':
                            features, _ = base_network(valid_inputs)
                            logits = -1.0 * loss.distance_to_centroids(
                                features, center_criterion.centers.detach())
                            source_logits = logits.narrow(
                                0, 0, valid_source_batch_size)

                        # Transfer loss - MMD
                        transfer_loss = transfer_criterion(
                            features[:valid_source_batch_size],
                            features[valid_source_batch_size:])

                        # Source domain classification task loss
                        classifier_loss = class_criterion(
                            source_logits, labels_valid_source.long())

                        # Final loss in case we do not want to add Fisher loss and Entropy minimization
                        total_loss = loss_params["trade_off"] * transfer_loss \
                                    + classifier_loss

                    # Logging:
                    if j % len(dset_loaders["source_valid"]) == 0:
                        config['out_file'].write(
                            'epoch {}: valid total loss={:0.4f}, valid transfer loss={:0.4f}, valid classifier loss={:0.4f}\n'
                            .format(
                                i / len(dset_loaders["source"]),
                                total_loss.data.cpu().float().item(),
                                transfer_loss.data.cpu().float().item(),
                                classifier_loss.data.cpu().float().item(),
                            ))
                        config['out_file'].flush()
                        # Logging for tensorboard:
                        writer.add_scalar("validation total loss",
                                          total_loss.data.cpu().float().item(),
                                          i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation classifier loss",
                            classifier_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation transfer loss",
                            transfer_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))

                        # Early stop in case we see overfitting
                        if early_stop_engine.is_stop_training(
                                classifier_loss.cpu().float().item()):
                            config["out_file"].write(
                                "no improvement after {}, stop training at step {}\n"
                                .format(config["early_stop_patience"],
                                        i / len(dset_loaders["source"])))

                            sys.exit()

        # In case we want to add Fisher loss and Entropy minimization
        else:
            fisher_loss, fisher_intra_loss, fisher_inter_loss, center_grad = center_criterion(
                features.narrow(0, 0, int(inputs.size(0) / 2)),
                labels_source,
                inter_class=loss_params["inter_type"],
                intra_loss_weight=loss_params["intra_loss_coef"],
                inter_loss_weight=loss_params["inter_loss_coef"])

            # Entropy minimization loss
            em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))
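            # (loss.EntropyLoss is assumed to compute the mean Shannon entropy -sum(p * log p)
            #  of the softmax outputs, so minimizing it pushes predictions on the concatenated
            #  source+target batch towards confident, low-entropy class assignments.)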

            # Final loss is the sum of all losses
            total_loss = loss_params["trade_off"] * transfer_loss \
                 + fisher_loss \
                 + loss_params["em_loss_coef"] * em_loss \
                 + classifier_loss

            scan_loss.append(total_loss.cpu().float().item())

            total_loss.backward()

            #################
            # Plot embeddings periodically.
            if args.blobs is not None and i / len(
                    dset_loaders["source"]) % 20 == 0:
                visualizePerformance(base_network,
                                     dset_loaders["source"],
                                     dset_loaders["target"],
                                     batch_size=128,
                                     num_of_samples=2000,
                                     imgName='embedding_' +
                                     str(i / len(dset_loaders["source"])),
                                     save_dir=osp.join(config["output_path"],
                                                       "blobs"))
            #################

            if center_grad is not None:
                # Clear any stale gradient on the class centers
                center_criterion.centers.grad.zero_()
                # Manually assign the centers' gradient instead of relying on autograd
                center_criterion.centers.backward(center_grad)
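            # (center_grad is assumed to be the analytic gradient of the Fisher term with
            #  respect to the class centers, returned by center_criterion above; since the
            #  centers were added to parameter_list, the same optimizer step updates them.)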

            optimizer.step()

            if i % config["log_iter"] == 0:

                # In case we want to do a learning rate scan to find the best lr_cycle length:
                if config['lr_scan'] != 'no':
                    if not osp.exists(
                            osp.join(config["output_path"],
                                     "learning_rate_scan")):
                        os.makedirs(
                            osp.join(config["output_path"],
                                     "learning_rate_scan"))

                    plot_learning_rate_scan(
                        scan_lr, scan_loss, i / len(dset_loaders["source"]),
                        osp.join(config["output_path"], "learning_rate_scan"))

                # In case we want to visualize gradients:
                if config['grad_vis'] != 'no':
                    if not osp.exists(
                            osp.join(config["output_path"], "gradients")):
                        os.makedirs(
                            osp.join(config["output_path"], "gradients"))

                    plot_grad_flow(
                        osp.join(config["output_path"], "gradients"),
                        i / len(dset_loaders["source"]),
                        base_network.named_parameters())

                # Logging
                config['out_file'].write(
                    'epoch {}: train total loss={:0.4f}, train transfer loss={:0.4f}, train classifier loss={:0.4f}, '
                    'train entropy min loss={:0.4f}, '
                    'train fisher loss={:0.4f}, train intra-group fisher loss={:0.4f}, train inter-group fisher loss={:0.4f}\n'
                    .format(
                        i / len(dset_loaders["source"]),
                        total_loss.data.cpu().float().item(),
                        transfer_loss.data.cpu().float().item(),
                        classifier_loss.data.cpu().float().item(),
                        em_loss.data.cpu().float().item(),
                        fisher_loss.cpu().float().item(),
                        fisher_intra_loss.cpu().float().item(),
                        fisher_inter_loss.cpu().float().item(),
                    ))
                config['out_file'].flush()
                # Logging for tensorboard
                writer.add_scalar("training total loss",
                                  total_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training transfer loss",
                                  transfer_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training classifier loss",
                                  classifier_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training entropy minimization loss",
                                  em_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training total fisher loss",
                                  fisher_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training intra-group fisher",
                                  fisher_intra_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training inter-group fisher",
                                  fisher_inter_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))

                #################
                # Validation step
                #################
                for j in range(0, len(dset_loaders["source_valid"])):
                    base_network.train(False)
                    with torch.no_grad():

                        try:
                            inputs_valid_source, labels_valid_source = next(
                                iter(dset_loaders["source_valid"]))
                            inputs_valid_target, labels_valid_target = next(
                                iter(dset_loaders["target_valid"]))

                        except StopIteration:
                            # Loader exhausted: restart it and fetch a fresh batch
                            inputs_valid_source, labels_valid_source = next(
                                iter(dset_loaders["source_valid"]))
                            inputs_valid_target, labels_valid_target = next(
                                iter(dset_loaders["target_valid"]))

                        if use_gpu:
                            inputs_valid_source, inputs_valid_target, labels_valid_source = \
                                Variable(inputs_valid_source).cuda(), Variable(inputs_valid_target).cuda(), \
                                Variable(labels_valid_source).cuda()
                        else:
                            inputs_valid_source, inputs_valid_target, labels_valid_source = Variable(inputs_valid_source), \
                                Variable(inputs_valid_target), Variable(labels_valid_source)

                        valid_inputs = torch.cat(
                            (inputs_valid_source, inputs_valid_target), dim=0)
                        valid_source_batch_size = inputs_valid_source.size(0)

                        # Distance type. We use cosine.
                        if config['loss']['ly_type'] == 'cosine':
                            features, logits = base_network(valid_inputs)
                            source_logits = logits.narrow(
                                0, 0, valid_source_batch_size)
                        elif config['loss']['ly_type'] == 'euclidean':
                            features, _ = base_network(valid_inputs)
                            logits = -1.0 * loss.distance_to_centroids(
                                features, center_criterion.centers.detach())
                            source_logits = logits.narrow(
                                0, 0, valid_source_batch_size)

                        # Transfer loss - MMD
                        transfer_loss = transfer_criterion(
                            features[:valid_source_batch_size],
                            features[valid_source_batch_size:])

                        # Source domain classification task loss
                        classifier_loss = class_criterion(
                            source_logits, labels_valid_source.long())
                        # Fisher loss on labeled source domain
                        fisher_loss, fisher_intra_loss, fisher_inter_loss, center_grad = center_criterion(
                            features.narrow(0, 0,
                                            int(valid_inputs.size(0) / 2)),
                            labels_valid_source,
                            inter_class=loss_params["inter_type"],
                            intra_loss_weight=loss_params["intra_loss_coef"],
                            inter_loss_weight=loss_params["inter_loss_coef"])

                        # Entropy minimization loss
                        em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))

                        # Final loss
                        total_loss = loss_params["trade_off"] * transfer_loss \
                                     + fisher_loss \
                                     + loss_params["em_loss_coef"] * em_loss \
                                     + classifier_loss

                    # Logging
                    if j % len(dset_loaders["source_valid"]) == 0:
                        config['out_file'].write(
                            'epoch {}, valid transfer loss={:0.4f}, valid classifier loss={:0.4f}, '
                            'valid entropy min loss={:0.4f}, '
                            'valid fisher loss={:0.4f}, valid intra-group fisher loss={:0.4f}, valid inter-group fisher loss={:0.4f}\n'
                            .format(
                                i / len(dset_loaders["source"]),
                                transfer_loss.data.cpu().float().item(),
                                classifier_loss.data.cpu().float().item(),
                                em_loss.data.cpu().float().item(),
                                fisher_loss.cpu().float().item(),
                                fisher_intra_loss.cpu().float().item(),
                                fisher_inter_loss.cpu().float().item(),
                            ))
                        config['out_file'].flush()
                        # Logging for tensorboard
                        writer.add_scalar("validation total loss",
                                          total_loss.data.cpu().float().item(),
                                          i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation transfer loss",
                            transfer_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation classifier loss",
                            classifier_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation entropy minimization loss",
                            em_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation total fisher loss",
                            fisher_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation intra-group fisher",
                            fisher_intra_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation inter-group fisher",
                            fisher_inter_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))

                        # Early stop in case we see overfitting
                        if early_stop_engine.is_stop_training(
                                classifier_loss.cpu().float().item()):
                            config["out_file"].write(
                                "no improvement after {}, stop training at epoch {}\n"
                                .format(config["early_stop_patience"],
                                        i / len(dset_loaders["source"])))

                            sys.exit()

    return best_acc
def test(config):
    
    # Fix random seed and enable deterministic calculations
    torch.manual_seed(config["seed"])
    torch.cuda.manual_seed(config["seed"])
    np.random.seed(config["seed"])
    torch.backends.cudnn.enabled=False
    torch.backends.cudnn.deterministic=True


    # Prepare image data. Image shuffling is fixed with the random seed choice.
    # Train:validation:test = 70:10:20
    dsets = {}
    dset_loaders = {}


    pristine_indices = torch.randperm(len(pristine_x))
    # Test sample for evaluation file
    pristine_x_test = pristine_x[pristine_indices[int(np.floor(.8*len(pristine_x))):]]
    pristine_y_test = pristine_y[pristine_indices[int(np.floor(.8*len(pristine_x))):]]

    noisy_indices = torch.randperm(len(noisy_x))
    # Test sample for evaluation file
    noisy_x_test = noisy_x[noisy_indices[int(np.floor(.8*len(noisy_x))):]]
    noisy_y_test = noisy_y[noisy_indices[int(np.floor(.8*len(noisy_x))):]]

    dsets["source_test"] = TensorDataset(pristine_x_test, pristine_y_test)
    dsets["target_test"] = TensorDataset(noisy_x_test, noisy_y_test)

    dset_loaders["source_test"] = DataLoader(dsets["source_test"], batch_size = 64, shuffle = True, num_workers = 1)
    dset_loaders["target_test"] = DataLoader(dsets["target_test"], batch_size = 64, shuffle = True, num_workers = 1)

    class_num = config["network"]["params"]["class_num"]

    # Load checkpoint
    print('load model from {}'.format(config['ckpt_path']))
    ckpt = torch.load(config['ckpt_path']+'/best_model.pth.tar')
    print('recorded best training accuracy: {:0.4f} at epoch {}'.format(ckpt["train accuracy"], ckpt["epoch"]))
    print('recorded best validation accuracy: {:.4f} at epoch {}'.format(ckpt["valid accuracy"], ckpt["epoch"]))

    train_accuracy = ckpt["train accuracy"]
    valid_accuracy = ckpt["valid accuracy"]
    
    # Set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    base_network.load_state_dict(ckpt['base_network'])

    centroids = None
    if 'center_criterion' in ckpt.keys():
        centroids = ckpt['center_criterion']['centers'].cpu()
    target_centroids = None
    if 'target_center_criterion' in ckpt.keys():
        target_centroids = ckpt['target_center_criterion']['centers'].cpu()

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ###################
    ###### TEST ######
    ###################
    print("start test: ")
    base_network.train(False)
    if config["ly_type"] == 'cosine':
        source_test_acc, source_test_confusion_matrix = image_classification_test(dset_loaders, "source_test", \
            base_network, gpu=use_gpu, verbose = True, save_where = config['ckpt_path'])
        target_test_acc, target_test_confusion_matrix = image_classification_test(dset_loaders, "target_test", \
            base_network, gpu=use_gpu, verbose = True, save_where = config['ckpt_path'])

    elif config["ly_type"] == "euclidean":
        eval_centroids = None
        if centroids is not None:
            eval_centroids = centroids
        if target_centroids is not None:
            eval_centroids = target_centroids
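        # (If checkpointed target centroids exist they take precedence over the source
        #  centroids, so distance classification uses the most recently adapted centers.)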
        source_test_acc, source_test_confusion_matrix = distance_classification_test(dset_loaders, "source_test", \
            base_network, eval_centroids, gpu=use_gpu, verbose = True, save_where = config['ckpt_path'])
        target_test_acc, target_test_confusion_matrix = distance_classification_test(dset_loaders, "target_test", \
            base_network, eval_centroids, gpu=use_gpu, verbose = True, save_where = config['ckpt_path'])

    # Save train/test accuracy as pkl file
    with open(os.path.join(config["output_path"], 'accuracy.pkl'), 'wb') as pkl_file:
        pkl.dump({'train': train_accuracy, 'valid': valid_accuracy, 'source test': source_test_acc, 'target test': target_test_acc}, pkl_file)
    
    # Logging
    np.set_printoptions(precision=2)
    log_str = "train accuracy: {:.5f}\tvalid accuracy: {:5f}\nsource test accuracy: {:.5f}\nsource confusion matrix:\n{}\ntarget test accuracy: {:.5f}\ntarget confusion matrix:\n{}\n".format(
        train_accuracy, valid_accuracy, source_test_acc, source_test_confusion_matrix, target_test_acc, target_test_confusion_matrix)
    config["out_file"].write(log_str)
    config["out_file"].flush()
    print(log_str)

    return (source_test_acc, target_test_acc)
def test(config):
    ## set pre-process
    # prep_dict = {}
    # prep_config = config["prep"]
    # prep_dict["source"] = prep.image_train( \
    #                         resize_size=prep_config["resize_size"], \
    #                         crop_size=prep_config["crop_size"])
    # prep_dict["target"] = prep.image_train( \
    #                         resize_size=prep_config["resize_size"], \
    #                         crop_size=prep_config["crop_size"])
    # if prep_config["test_10crop"]:
    #     prep_dict["test"] = prep.image_test_10crop( \
    #                         resize_size=prep_config["resize_size"], \
    #                         crop_size=prep_config["crop_size"])
    # else:
    #     prep_dict["test"] = prep.image_test( \
    #                         resize_size=prep_config["resize_size"], \
    #                         crop_size=prep_config["crop_size"])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    #data_config = config["data"]

    # Shuffle with a random permutation; the middle 10% is held out for validation (train:validation:test = 70:10:20)
    pristine_indices = torch.randperm(len(pristine_x))
    #train
    pristine_x_train = pristine_x[pristine_indices[:int(np.floor(.7*len(pristine_x)))]]
    pristine_y_train = pristine_y[pristine_indices[:int(np.floor(.7*len(pristine_x)))]]
    #validate --- gets passed into test functions in train file
    pristine_x_valid = pristine_x[pristine_indices[int(np.floor(.7*len(pristine_x))) : int(np.floor(.8*len(pristine_x)))]]
    pristine_y_valid = pristine_y[pristine_indices[int(np.floor(.7*len(pristine_x))) : int(np.floor(.8*len(pristine_x)))]]
    #test for evaluation file
    pristine_x_test = pristine_x[pristine_indices[int(np.floor(.8*len(pristine_x))):]]
    pristine_y_test = pristine_y[pristine_indices[int(np.floor(.8*len(pristine_x))):]]

    noisy_indices = torch.randperm(len(noisy_x))
    #train
    noisy_x_train = noisy_x[noisy_indices[:int(np.floor(.7*len(noisy_x)))]]
    noisy_y_train = noisy_y[noisy_indices[:int(np.floor(.7*len(noisy_x)))]]
    #validate --- gets passed into test functions in train file
    noisy_x_valid = noisy_x[noisy_indices[int(np.floor(.7*len(noisy_x))) : int(np.floor(.8*len(noisy_x)))]]
    noisy_y_valid = noisy_y[noisy_indices[int(np.floor(.7*len(noisy_x))) : int(np.floor(.8*len(noisy_x)))]]
    #test for evaluation file
    noisy_x_test = noisy_x[noisy_indices[int(np.floor(.8*len(noisy_x))):]]
    noisy_y_test = noisy_y[noisy_indices[int(np.floor(.8*len(noisy_x))):]]


    dsets["source"] = TensorDataset(pristine_x_train, pristine_y_train)
    dsets["target"] = TensorDataset(noisy_x_train, noisy_y_train)

    dsets["source_valid"] = TensorDataset(pristine_x_valid, pristine_y_valid)
    dsets["target_valid"] = TensorDataset(noisy_x_valid, noisy_y_valid)

    dsets["source_test"] = TensorDataset(pristine_x_test, pristine_y_test)
    dsets["target_test"] = TensorDataset(noisy_x_test, noisy_y_test)


    # data_config = config["data"]
    # dsets["source"] = ImageList(stratify_sampling(open(data_config["source"]["list_path"]).readlines(), ratio=prep_config["source_size"]), \
    #                             transform=prep_dict["source"])

    # Data loaders for each split
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size = 36, shuffle = True, num_workers = 1)
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size = 36, shuffle = True, num_workers = 1)

    # Smaller batch size for the validation and test loaders
    dset_loaders["source_valid"] = DataLoader(dsets["source_valid"], batch_size = 4, shuffle = True, num_workers = 1)
    dset_loaders["target_valid"] = DataLoader(dsets["target_valid"], batch_size = 4, shuffle = True, num_workers = 1)

    dset_loaders["source_test"] = DataLoader(dsets["source_test"], batch_size = 4, shuffle = True, num_workers = 1)
    dset_loaders["target_test"] = DataLoader(dsets["target_test"], batch_size = 4, shuffle = True, num_workers = 1)


    # dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(), \
    #                             transform=prep_dict["source"])
    # dset_loaders["source"] = util_data.DataLoader(dsets["source"], \
    #         batch_size=data_config["source"]["batch_size"], \
    #         shuffle=True, num_workers=2)
    # dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
    #                             transform=prep_dict["target"])
    # dset_loaders["target"] = util_data.DataLoader(dsets["target"], \
    #         batch_size=data_config["target"]["batch_size"], \
    #         shuffle=True, num_workers=2)

    # if prep_config["test_10crop"]:
    #     for i in range(10):
    #         dsets["test"+str(i)] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
    #                             transform=prep_dict["test"]["val"+str(i)])
    #         dset_loaders["test"+str(i)] = util_data.DataLoader(dsets["test"+str(i)], \
    #                             batch_size=data_config["test"]["batch_size"], \
    #                             shuffle=False, num_workers=2)

    #         dsets["target"+str(i)] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
    #                             transform=prep_dict["test"]["val"+str(i)])
    #         dset_loaders["target"+str(i)] = util_data.DataLoader(dsets["target"+str(i)], \
    #                             batch_size=data_config["test"]["batch_size"], \
    #                             shuffle=False, num_workers=2)
    # else:
    #     dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
    #                             transform=prep_dict["test"])
    #     dset_loaders["test"] = util_data.DataLoader(dsets["test"], \
    #                             batch_size=data_config["test"]["batch_size"], \
    #                             shuffle=False, num_workers=2)

    #     dsets["target_test"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
    #                             transform=prep_dict["test"])
    #     dset_loaders["target_test"] = util_data.DataLoader(dsets["target_test"], \
    #                             batch_size=data_config["test"]["batch_size"], \
    #                             shuffle=False, num_workers=2)

    class_num = config["network"]["params"]["class_num"]

    # load checkpoint
    print('load model from {}'.format(config['ckpt_path']))
    # load in an old way
    # base_network = torch.load(config["ckpt_path"])[0]
    # recommended practice
    ckpt = torch.load(config['ckpt_path'])
    print('recorded best training accuracy: {:0.4f} at step {}'.format(ckpt["train accuracy"], ckpt["step"]))
    print('recorded best validation accuracy: {:.4f} at step {}'.format(ckpt["valid accuracy"], ckpt["step"]))
    train_accuracy = ckpt["train accuracy"]
    valid_accuracy = ckpt["valid accuracy"]
    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    base_network.load_state_dict(ckpt['base_network'])

    centroids = None
    if 'center_criterion' in ckpt.keys():
        centroids = ckpt['center_criterion']['centers'].cpu()
    target_centroids = None
    if 'target_center_criterion' in ckpt.keys():
        target_centroids = ckpt['target_center_criterion']['centers'].cpu()

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## test
    print("start test: ")
    base_network.train(False)
    if config["ly_type"] == 'cosine':
        test_acc, test_confusion_matrix = image_classification_test(dset_loaders, str(config["domain"]), \
            base_network, \
            gpu=use_gpu)
    elif config["ly_type"] == "euclidean":
        eval_centroids = None
        if centroids is not None:
            eval_centroids = centroids
        if target_centroids is not None:
            eval_centroids = target_centroids
        test_acc, test_confusion_matrix = distance_classification_test(dset_loaders, str(config["domain"]), \
            base_network, eval_centroids, \
            gpu=use_gpu)

    # save train/test accuracy as pkl file
    with open(os.path.join(config["output_path"], 'accuracy.pkl'), 'wb') as pkl_file:
        pkl.dump({'train': train_accuracy, 'valid': valid_accuracy, 'test': test_acc}, pkl_file)
    
    np.set_printoptions(precision=2)
    log_str = "train accuracy: {:.5f}\tvalid accuracy: {:5f}\ttest accuracy: {:.5f}\nconfusion matrix:\n{}\n".format(
        train_accuracy, valid_accuracy, test_acc, test_confusion_matrix)
    config["out_file"].write(log_str)
    config["out_file"].flush()
    print(log_str)

    return test_acc
def train(config):
    ## set up summary writer
    writer = SummaryWriter(config['output_path'])
    class_num = config["network"]["params"]["class_num"]
    class_criterion = nn.CrossEntropyLoss()
    loss_params = config["loss"]

    ## prepare data
    dsets = {}
    dset_loaders = {}

    # Shuffle with a random permutation; the middle 10% is held out for validation (train:validation:test = 70:10:20)
    pristine_indices = torch.randperm(len(pristine_x))
    #train
    pristine_x_train = pristine_x[
        pristine_indices[:int(np.floor(.7 * len(pristine_x)))]]
    pristine_y_train = pristine_y[
        pristine_indices[:int(np.floor(.7 * len(pristine_x)))]]
    #validate --- gets passed into test functions in train file
    pristine_x_valid = pristine_x[pristine_indices[
        int(np.floor(.7 * len(pristine_x))):int(np.floor(.8 *
                                                         len(pristine_x)))]]
    pristine_y_valid = pristine_y[pristine_indices[
        int(np.floor(.7 * len(pristine_x))):int(np.floor(.8 *
                                                         len(pristine_x)))]]
    #test for evaluation file
    pristine_x_test = pristine_x[
        pristine_indices[int(np.floor(.8 * len(pristine_x))):]]
    pristine_y_test = pristine_y[
        pristine_indices[int(np.floor(.8 * len(pristine_x))):]]

    noisy_indices = torch.randperm(len(noisy_x))
    #train
    noisy_x_train = noisy_x[noisy_indices[:int(np.floor(.7 * len(noisy_x)))]]
    noisy_y_train = noisy_y[noisy_indices[:int(np.floor(.7 * len(noisy_x)))]]
    #validate --- gets passed into test functions in train file
    noisy_x_valid = noisy_x[noisy_indices[int(np.floor(.7 * len(noisy_x))
                                              ):int(np.floor(.8 *
                                                             len(noisy_x)))]]
    noisy_y_valid = noisy_y[noisy_indices[int(np.floor(.7 * len(noisy_x))
                                              ):int(np.floor(.8 *
                                                             len(noisy_x)))]]
    #test for evaluation file
    noisy_x_test = noisy_x[noisy_indices[int(np.floor(.8 * len(noisy_x))):]]
    noisy_y_test = noisy_y[noisy_indices[int(np.floor(.8 * len(noisy_x))):]]

    dsets["source"] = TensorDataset(pristine_x_train, pristine_y_train)
    dsets["target"] = TensorDataset(noisy_x_train, noisy_y_train)

    dsets["source_valid"] = TensorDataset(pristine_x_valid, pristine_y_valid)
    dsets["target_valid"] = TensorDataset(noisy_x_valid, noisy_y_valid)

    dsets["source_test"] = TensorDataset(pristine_x_test, pristine_y_test)
    dsets["target_test"] = TensorDataset(noisy_x_test, noisy_y_test)

    # Data loaders for each split
    dset_loaders["source"] = DataLoader(dsets["source"],
                                        batch_size=128,
                                        shuffle=True,
                                        num_workers=1)
    dset_loaders["target"] = DataLoader(dsets["target"],
                                        batch_size=128,
                                        shuffle=True,
                                        num_workers=1)

    # Smaller batch size for the validation and test loaders
    dset_loaders["source_valid"] = DataLoader(dsets["source_valid"],
                                              batch_size=64,
                                              shuffle=True,
                                              num_workers=1)
    dset_loaders["target_valid"] = DataLoader(dsets["target_valid"],
                                              batch_size=64,
                                              shuffle=True,
                                              num_workers=1)

    dset_loaders["source_test"] = DataLoader(dsets["source_test"],
                                             batch_size=64,
                                             shuffle=True,
                                             num_workers=1)
    dset_loaders["target_test"] = DataLoader(dsets["target_test"],
                                             batch_size=64,
                                             shuffle=True,
                                             num_workers=1)

    config['out_file'].write("dataset sizes: source={}\n".format(
        len(dsets["source"])))

    config["num_iterations"] = len(
        dset_loaders["source"]) * config["epochs"] + 1
    config["test_interval"] = len(dset_loaders["source"])
    config["snapshot_interval"] = len(
        dset_loaders["source"]) * config["epochs"] * .25
    config["log_iter"] = len(dset_loaders["source"])

    #print the configuration you are using
    config["out_file"].write("config: {}\n".format(config))
    config["out_file"].flush()

    # set up early stop
    early_stop_engine = EarlyStopping(config["early_stop_patience"])

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## collect parameters
    if "DeepMerge" in args.net:
        parameter_list = [{
            "params": base_network.parameters(),
            "lr_mult": 1,
            'decay_mult': 2
        }]
    elif net_config["params"]["new_cls"]:
        if net_config["params"]["use_bottleneck"]:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                            {"params":base_network.bottleneck.parameters(), "lr_mult":10, 'decay_mult':2}, \
                            {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
        else:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":10, 'decay_mult':2}, \
                            {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
    else:
        parameter_list = [{
            "params": base_network.parameters(),
            "lr_mult": 10,
            'decay_mult': 2
        }]

    ## Class weights (balanced sample here, so all weights are 1.0); the criterion's parameters (if any) are appended to the optimizer list below
    class_weight = torch.from_numpy(np.array([1.0] * class_num))
    if use_gpu:
        class_weight = class_weight.cuda()

    parameter_list.append({
        "params": class_criterion.parameters(),
        "lr_mult": 10,
        'decay_mult': 1
    })

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    scan_lr = []
    scan_loss = []

    ## train
    len_train_source = len(dset_loaders["source"])
    len_valid_source = len(dset_loaders["source_valid"])

    classifier_loss_value = 0.0
    best_acc = 0.0

    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == 0:
            base_network.train(False)
            if config['loss']['ly_type'] == "cosine":
                temp_acc, _ = image_classification_test(dset_loaders, 'source_valid', \
                    base_network, gpu=use_gpu, verbose = False, save_where = None)
                train_acc, _ = image_classification_test(dset_loaders, 'source', \
                    base_network, gpu=use_gpu, verbose = False, save_where = None)
            elif config['loss']['ly_type'] == 'euclidean':
                raise ValueError(
                    'The euclidean distance loss cannot be used here because it involves the target domain'
                )
            else:
                raise ValueError("no test method for cls loss: {}".format(
                    config['loss']['ly_type']))

            snapshot_obj = {
                'epoch': i / len(dset_loaders["source"]),
                "base_network": base_network.state_dict(),
                'valid accuracy': temp_acc,
                'train accuracy': train_acc,
            }

            snapshot_obj['class_criterion'] = class_criterion.state_dict()

            if (i + 1) % config["snapshot_interval"] == 0:
                torch.save(
                    snapshot_obj,
                    osp.join(
                        config["output_path"], "epoch_{}_model.pth.tar".format(
                            i / len(dset_loaders["source"]))))

            if temp_acc > best_acc:
                best_acc = temp_acc
                # save best model
                torch.save(
                    snapshot_obj,
                    osp.join(config["output_path"], "best_model.pth.tar"))
            log_str = "epoch: {}, {} validation accuracy: {:.5f}, {} training accuracy: {:.5f}\n".format(
                i / len(dset_loaders["source"]), config['loss']['ly_type'],
                temp_acc, config['loss']['ly_type'], train_acc)
            config["out_file"].write(log_str)
            config["out_file"].flush()
            writer.add_scalar("validation accuracy", temp_acc,
                              i / len(dset_loaders["source"]))
            writer.add_scalar("training accuracy", train_acc,
                              i / len(dset_loaders["source"]))

        ## train one iter
        base_network.train(True)

        if i % config["log_iter"] == 0:
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)

        if config["optimizer"]["lr_type"] == "one-cycle":
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)

        if config["optimizer"]["lr_type"] == "linear":
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)

        optim = optimizer.state_dict()
        scan_lr.append(optim['param_groups'][0]['lr'])

        optimizer.zero_grad()

        # A fresh iterator over the shuffled, non-empty source loader yields one random batch.
        inputs_source, labels_source = next(iter(dset_loaders["source"]))

        if use_gpu:
            inputs_source, labels_source = Variable(
                inputs_source).cuda(), Variable(labels_source).cuda()
        else:
            inputs_source, labels_source = Variable(inputs_source), Variable(
                labels_source)

        inputs = inputs_source
        source_batch_size = inputs_source.size(0)

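        # Forward pass; narrow(0, 0, source_batch_size) keeps the source rows
        # of the logits (here a no-op, since the batch is source-only).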
        features, logits = base_network(inputs)
        source_logits = logits.narrow(0, 0, source_batch_size)

        # source domain classification task loss
        classifier_loss = class_criterion(source_logits, labels_source.long())

        total_loss = classifier_loss

        scan_loss.append(total_loss.cpu().float().item())

        total_loss.backward()

        ######################################
        # Plot embeddings periodically.
        if args.blobs is not None and i / len(
                dset_loaders["source"]) % 50 == 0:
            visualizePerformance(
                base_network,
                dset_loaders["source"],
                dset_loaders["target"],
                batch_size=128,
                num_of_samples=100,
                imgName='embedding_' + str(i / len(dset_loaders["source"])),
                save_dir=osp.join(config["output_path"], "blobs"))
        ##########################################

        optimizer.step()

        if i % config["log_iter"] == 0:

            if config['lr_scan'] != 'no':
                if not osp.exists(
                        osp.join(config["output_path"], "learning_rate_scan")):
                    os.makedirs(
                        osp.join(config["output_path"], "learning_rate_scan"))

                plot_learning_rate_scan(
                    scan_lr, scan_loss, i / len(dset_loaders["source"]),
                    osp.join(config["output_path"], "learning_rate_scan"))

            if config['grad_vis'] != 'no':
                if not osp.exists(osp.join(config["output_path"],
                                           "gradients")):
                    os.makedirs(osp.join(config["output_path"], "gradients"))

                plot_grad_flow(osp.join(config["output_path"], "gradients"),
                               i / len(dset_loaders["source"]),
                               base_network.named_parameters())

            config['out_file'].write('epoch {}: train total loss={:0.4f}, train classifier loss={:0.4f}\n'.format(i/len(dset_loaders["source"]), \
                total_loss.data.cpu(), classifier_loss.data.cpu().float().item(),))
            config['out_file'].flush()
            writer.add_scalar("training total loss",
                              total_loss.data.cpu().float().item(),
                              i / len(dset_loaders["source"]))
            writer.add_scalar("training classifier loss",
                              classifier_loss.data.cpu().float().item(),
                              i / len(dset_loaders["source"]))

            # validation step (no backprop)
            for j in range(0, len(dset_loaders["source_valid"])):
                base_network.train(False)
                with torch.no_grad():

                    # A fresh iterator over the validation loader yields one batch.
                    inputs_source, labels_source = next(
                        iter(dset_loaders["source_valid"]))

                    if use_gpu:
                        inputs_source, labels_source = Variable(
                            inputs_source).cuda(), Variable(
                                labels_source).cuda()
                    else:
                        inputs_source, labels_source = Variable(
                            inputs_source), Variable(labels_source)

                    inputs = inputs_source
                    source_batch_size = inputs_source.size(0)

                    features, logits = base_network(inputs)
                    source_logits = logits.narrow(0, 0, source_batch_size)

                    # source domain classification task loss
                    classifier_loss = class_criterion(source_logits,
                                                      labels_source.long())

                    # final loss
                    total_loss = classifier_loss
                    #total_loss.backward() no backprop on the eval mode

                if j % len(dset_loaders["source_valid"]) == 0:
                    config['out_file'].write('epoch {}: valid total loss={:0.4f}, valid classifier loss={:0.4f}\n'.format(i/len(dset_loaders["source"]), \
                        total_loss.data.cpu(), classifier_loss.data.cpu().float().item(),))
                    config['out_file'].flush()
                    writer.add_scalar("validation total loss",
                                      total_loss.data.cpu().float().item(),
                                      i / len(dset_loaders["source"]))
                    writer.add_scalar(
                        "validation classifier loss",
                        classifier_loss.data.cpu().float().item(),
                        i / len(dset_loaders["source"]))

                    if early_stop_engine.is_stop_training(
                            classifier_loss.cpu().float().item()):
                        config["out_file"].write(
                            "overfitting after {}, stop training at epoch {}\n"
                            .format(config["early_stop_patience"],
                                    i / len(dset_loaders["source"])))

                        sys.exit()

    return best_acc
def train(config):
    ## set up summary writer
    writer = SummaryWriter(config['output_path'])

    # set up early stop
    early_stop_engine = EarlyStopping(config["early_stop_patience"])
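    # EarlyStopping is assumed to signal a stop once the monitored metric has
    # not improved for "early_stop_patience" consecutive checks.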

    class_num = config["network"]["params"]["class_num"]

    class_criterion = nn.CrossEntropyLoss()

    # transfer_criterion = config["loss"]["name"]
    #center_criterion = config["loss"]["discriminant_loss"](num_classes=class_num, feat_dim=config["network"]["params"]["bottleneck_dim"])
    
    loss_params = config["loss"]

    ## prepare data
    dsets = {}
    dset_loaders = {}

    # Sample without replacement: 70/10/20 train/validation/test split.
    pristine_indices = torch.randperm(len(pristine_x))
    #train
    pristine_x_train = pristine_x[pristine_indices[:int(np.floor(.7*len(pristine_x)))]]
    pristine_y_train = pristine_y[pristine_indices[:int(np.floor(.7*len(pristine_x)))]]
    #validate --- gets passed into test functions in train file
    pristine_x_valid = pristine_x[pristine_indices[int(np.floor(.7*len(pristine_x))) : int(np.floor(.8*len(pristine_x)))]]
    pristine_y_valid = pristine_y[pristine_indices[int(np.floor(.7*len(pristine_x))) : int(np.floor(.8*len(pristine_x)))]]
    #test for evaluation file
    pristine_x_test = pristine_x[pristine_indices[int(np.floor(.8*len(pristine_x))):]]
    pristine_y_test = pristine_y[pristine_indices[int(np.floor(.8*len(pristine_x))):]]

    # noisy_indices = torch.randperm(len(noisy_x))
    # #train
    # noisy_x_train = noisy_x[noisy_indices[:int(np.floor(.7*len(noisy_x)))]]
    # noisy_y_train = noisy_y[noisy_indices[:int(np.floor(.7*len(noisy_x)))]]
    # #validate --- gets passed into test functions in train file
    # noisy_x_valid = noisy_x[noisy_indices[int(np.floor(.7*len(noisy_x))) : int(np.floor(.8*len(noisy_x)))]]
    # noisy_y_valid = noisy_y[noisy_indices[int(np.floor(.7*len(noisy_x))) : int(np.floor(.8*len(noisy_x)))]]
    # #test for evaluation file
    # noisy_x_test = noisy_x[noisy_indices[int(np.floor(.8*len(noisy_x))):]]
    # noisy_y_test = noisy_y[noisy_indices[int(np.floor(.8*len(noisy_x))):]]

    dsets["source"] = TensorDataset(pristine_x_train, pristine_y_train)
    # dsets["target"] = TensorDataset(noisy_x_train, noisy_y_train)

    dsets["source_valid"] = TensorDataset(pristine_x_valid, pristine_y_valid)
    # dsets["target_valid"] = TensorDataset(noisy_x_valid, noisy_y_valid)

    dsets["source_test"] = TensorDataset(pristine_x_test, pristine_y_test)
    # dsets["target_test"] = TensorDataset(noisy_x_test, noisy_y_test)


    # data_config = config["data"]
    # dsets["source"] = ImageList(stratify_sampling(open(data_config["source"]["list_path"]).readlines(), ratio=prep_config["source_size"]), \
    #                             transform=prep_dict["source"])

    # Data loaders (batch sizes follow the original implementation)
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size = 36, shuffle = True, num_workers = 1)
    # dset_loaders["target"] = DataLoader(dsets["target"], batch_size = 36, shuffle = True, num_workers = 1)

    # Validation/test batch sizes follow the test setup of the original implementation
    dset_loaders["source_valid"] = DataLoader(dsets["source_valid"], batch_size = 4, shuffle = True, num_workers = 1)
    # dset_loaders["target_valid"] = DataLoader(dsets["target_valid"], batch_size = 4, shuffle = True, num_workers = 1)

    dset_loaders["source_test"] = DataLoader(dsets["source_test"], batch_size = 4, shuffle = True, num_workers = 1)
    # dset_loaders["target_test"] = DataLoader(dsets["target_test"], batch_size = 4, shuffle = True, num_workers = 1)

    config['out_file'].write("dataset sizes: source={}\n".format(
        len(dsets["source"]))) #TODO: change this too

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])


    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## collect parameters
    if net_config["params"]["new_cls"]:
        if net_config["params"]["use_bottleneck"]:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                            {"params":base_network.bottleneck.parameters(), "lr_mult":10, 'decay_mult':2}, \
                            {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
        else:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                            {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
    else:
        parameter_list = [{"params":base_network.parameters(), "lr_mult":1, 'decay_mult':2}]

    ## uniform per-class weights (the center-criterion parameter group is disabled below)
    class_weight = torch.from_numpy(np.array([1.0] * class_num))
    if use_gpu:
        class_weight = class_weight.cuda()
    #parameter_list.append({"params":center_criterion.parameters(), "lr_mult": 10, 'decay_mult':1})
 
    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]


    ## train   
    len_train_source = len(dset_loaders["source"]) - 1
    # len_train_target = len(dset_loaders["target"]) - 1
    len_valid_source = len(dset_loaders["source_valid"]) - 1
    # len_valid_target = len(dset_loaders["target_valid"]) - 1
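    # The -1 presumably lets "i % len_* == 0" re-create each iterator one step
    # before it would otherwise be exhausted.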

    classifier_loss_value = 0.0
    # transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0

    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == 0:
            base_network.train(False)
            if config['loss']['ly_type'] == "cosine":
                temp_acc, _ = image_classification_test(dset_loaders, 'source_valid', \
                    base_network, \
                    gpu=use_gpu)
                train_acc, _ = image_classification_test(dset_loaders, 'source', \
                    base_network, \
                    gpu=use_gpu)
            # elif config['loss']['ly_type'] == "euclidean":
            #     temp_acc, _ = distance_classification_test(dset_loaders, 'source_valid', \
            #         base_network, center_criterion.centers.detach(), \
            #         gpu=use_gpu)
            #     train_acc, _ = distance_classification_test(dset_loaders, 'source', \
            #         base_network, \
            #         gpu=use_gpu)
            else:
                raise ValueError("no test method for cls loss: {}".format(config['loss']['ly_type']))
            
            snapshot_obj = {'step': i, 
                            "base_network": base_network.state_dict(), 
                            'valid accuracy': temp_acc,
                            'train accuracy' : train_acc,
                            }
            # snapshot_obj['center_criterion'] = center_criterion.state_dict()
            if temp_acc > best_acc:
                best_acc = temp_acc
                # save best model
                torch.save(snapshot_obj, 
                           osp.join(config["output_path"], "best_model.pth.tar"))
            log_str = "iter: {:05d}, {} validation accuracy: {:.5f}, {} training accuracy: {:.5f}\n".format(i, config['loss']['ly_type'], temp_acc, config['loss']['ly_type'], train_acc)
            config["out_file"].write(log_str)
            config["out_file"].flush()
            writer.add_scalar("validation accuracy", temp_acc, i)
            writer.add_scalar("training accuracy", train_acc, i)

            if early_stop_engine.is_stop_training(temp_acc):
                config["out_file"].write("no improvement after {}, stop training at step {}\n".format(
                    config["early_stop_patience"], i))
                # config["out_file"].write("finish training! \n")
                break

        if (i+1) % config["snapshot_interval"] == 0:
            torch.save(snapshot_obj, 
                        osp.join(config["output_path"], "iter_{:05d}_model.pth.tar".format(i)))
                    

        ## train one iter
        base_network.train(True)
        optimizer = lr_scheduler(param_lr, optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        # if i % len_train_target == 0:
        #     iter_target = iter(dset_loaders["target"])

        try:
            inputs_source, labels_source = next(iter_source)
            # inputs_target, labels_target = next(iter_target)
        except StopIteration:
            iter_source = iter(dset_loaders["source"])
            inputs_source, labels_source = next(iter_source)
            # iter_target = iter(dset_loaders["target"])

        if use_gpu:
            inputs_source, labels_source = Variable(inputs_source).cuda(), Variable(labels_source).cuda()
        else:
            inputs_source, labels_source = Variable(inputs_source), Variable(labels_source)
           
        inputs = inputs_source
        source_batch_size = inputs_source.size(0)

        # if config['loss']['ly_type'] == 'cosine':
        features, logits = base_network(inputs)
        source_logits = logits.narrow(0, 0, source_batch_size)
        # elif config['loss']['ly_type'] == 'euclidean':
        #     features, _ = base_network(inputs)
        #     logits = -1.0 * loss.distance_to_centroids(features, center_criterion.centers.detach())
        #     source_logits = logits.narrow(0, 0, source_batch_size)

        # transfer_loss = transfer_criterion(features[:source_batch_size], features[source_batch_size:])

        # source domain classification task loss
        classifier_loss = class_criterion(source_logits, labels_source.long())
        # fisher loss on labeled source domain
        # fisher_loss, fisher_intra_loss, fisher_inter_loss, center_grad = center_criterion(features.narrow(0, 0, int(inputs.size(0)/2)), labels_source, inter_class=loss_params["inter_type"], 
                                                                               # intra_loss_weight=loss_params["intra_loss_coef"], inter_loss_weight=loss_params["inter_loss_coef"])
        # entropy minimization loss
        em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))

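        # Training objective: cross-entropy on the source labels plus the
        # entropy-minimization term weighted by "em_loss_coef".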
        total_loss = loss_params["em_loss_coef"] * em_loss + classifier_loss
   
        total_loss.backward() #we need to fix this

        # if center_grad is not None:
        #     # clear mmc_loss
        #     center_criterion.centers.grad.zero_()
        #     # Manually assign centers gradients other than using autograd
        #     center_criterion.centers.backward(center_grad)

        optimizer.step()

        if i % config["log_iter"] == 0:
            config['out_file'].write('iter {}: train total loss={:0.4f}, train classifier loss={:0.4f}, '
                'train entropy min loss={:0.4f}\n'.format(
                i, total_loss.data.cpu(), classifier_loss.data.cpu().float().item(), em_loss.data.cpu().float().item(), 
                ))
            config['out_file'].flush()
            writer.add_scalar("training total loss", total_loss.data.cpu().float().item(), i)
            writer.add_scalar("training classifier loss", classifier_loss.data.cpu().float().item(), i)
            writer.add_scalar("training entropy minimization loss", em_loss.data.cpu().float().item(), i)
            # writer.add_scalar("training transfer loss", transfer_loss.data.cpu().float().item(), i)
            # writer.add_scalar("training total fisher loss", fisher_loss.data.cpu().float().item(), i)
            # writer.add_scalar("training intra-group fisher", fisher_intra_loss.data.cpu().float().item(), i)
            # writer.add_scalar("training inter-group fisher", fisher_inter_loss.data.cpu().float().item(), i)

        # validation step (one batch per training iteration; no backprop)
        base_network.eval()
        with torch.no_grad():
            if i % len_valid_source == 0:
                iter_valid_source = iter(dset_loaders["source_valid"])
            # if i % len_valid_target == 0:
            #     iter_valid_target = iter(dset_loaders["target_valid"])

            try:
                inputs_source, labels_source = next(iter_valid_source)
                # inputs_target, labels_target = next(iter_valid_target)
            except StopIteration:
                iter_valid_source = iter(dset_loaders["source_valid"])
                inputs_source, labels_source = next(iter_valid_source)
                # iter_valid_target = iter(dset_loaders["target_valid"])

            if use_gpu:
                inputs_source, labels_source = Variable(inputs_source).cuda(), Variable(labels_source).cuda()
            else:
                inputs_source,labels_source = Variable(inputs_source), Variable(labels_source)
               
            inputs = inputs_source
            source_batch_size = inputs_source.size(0)

            # if config['loss']['ly_type'] == 'cosine':
            features, logits = base_network(inputs)
            source_logits = logits.narrow(0, 0, source_batch_size)
            # elif config['loss']['ly_type'] == 'euclidean':
            #     features, _ = base_network(inputs)
            #     logits = -1.0 * loss.distance_to_centroids(features, center_criterion.centers.detach())
            #     source_logits = logits.narrow(0, 0, source_batch_size)

            # transfer_loss = transfer_criterion(features[:source_batch_size], features[source_batch_size:])

            # source domain classification task loss
            classifier_loss = class_criterion(source_logits, labels_source.long())
            # fisher loss on labeled source domain
            # fisher_loss, fisher_intra_loss, fisher_inter_loss, center_grad = center_criterion(features.narrow(0, 0, int(inputs.size(0)/2)), labels_source, inter_class=loss_params["inter_type"], 
                                                                                   # intra_loss_weight=loss_params["intra_loss_coef"], inter_loss_weight=loss_params["inter_loss_coef"])
            # entropy minimization loss
            em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))
            
            # final loss
            total_loss = loss_params["em_loss_coef"] * em_loss + classifier_loss
            #total_loss.backward() no backprop on the eval mode

        if i % config["log_iter"] == 0:
            config['out_file'].write('iter {}: validation total loss={:0.4f}, validation classifier loss={:0.4f}, '
                'validation entropy min loss={:0.4f}\n'.format(
                i, total_loss.data.cpu(), classifier_loss.data.cpu().float().item(), em_loss.data.cpu().float().item(), 
                ))
            config['out_file'].flush()
            writer.add_scalar("validation total loss", total_loss.data.cpu().float().item(), i)
            writer.add_scalar("validation classifier loss", classifier_loss.data.cpu().float().item(), i)
            writer.add_scalar("validation entropy minimization loss", em_loss.data.cpu().float().item(), i)
            # writer.add_scalar("validation transfer loss", transfer_loss.data.cpu().float().item(), i)
            # writer.add_scalar("validation total fisher loss", fisher_loss.data.cpu().float().item(), i)
            # writer.add_scalar("validation intra-group fisher", fisher_intra_loss.data.cpu().float().item(), i)
            # writer.add_scalar("validation inter-group fisher", fisher_inter_loss.data.cpu().float().item(), i)
            
    return best_acc
def test(config):
    ## prepare data
    dsets = {}
    dset_loaders = {}

    # Sample without replacement: 70/10/20 train/validation/test split.
    pristine_indices = torch.randperm(len(pristine_x))
    #train
    pristine_x_train = pristine_x[
        pristine_indices[:int(np.floor(.7 * len(pristine_x)))]]
    pristine_y_train = pristine_y[
        pristine_indices[:int(np.floor(.7 * len(pristine_x)))]]
    #validate --- gets passed into test functions in train file
    pristine_x_valid = pristine_x[pristine_indices[
        int(np.floor(.7 * len(pristine_x))):int(np.floor(.8 *
                                                         len(pristine_x)))]]
    pristine_y_valid = pristine_y[pristine_indices[
        int(np.floor(.7 * len(pristine_x))):int(np.floor(.8 *
                                                         len(pristine_x)))]]
    #test for evaluation file
    pristine_x_test = pristine_x[
        pristine_indices[int(np.floor(.8 * len(pristine_x))):]]
    pristine_y_test = pristine_y[
        pristine_indices[int(np.floor(.8 * len(pristine_x))):]]

    noisy_indices = torch.randperm(len(noisy_x))
    #train
    noisy_x_train = noisy_x[noisy_indices[:int(np.floor(.7 * len(noisy_x)))]]
    noisy_y_train = noisy_y[noisy_indices[:int(np.floor(.7 * len(noisy_x)))]]
    #validate --- gets passed into test functions in train file
    noisy_x_valid = noisy_x[noisy_indices[int(np.floor(.7 * len(noisy_x))
                                              ):int(np.floor(.8 *
                                                             len(noisy_x)))]]
    noisy_y_valid = noisy_y[noisy_indices[int(np.floor(.7 * len(noisy_x))
                                              ):int(np.floor(.8 *
                                                             len(noisy_x)))]]
    #test for evaluation file
    noisy_x_test = noisy_x[noisy_indices[int(np.floor(.8 * len(noisy_x))):]]
    noisy_y_test = noisy_y[noisy_indices[int(np.floor(.8 * len(noisy_x))):]]

    dsets["source"] = TensorDataset(pristine_x_train, pristine_y_train)
    dsets["target"] = TensorDataset(noisy_x_train, noisy_y_train)

    dsets["source_valid"] = TensorDataset(pristine_x_valid, pristine_y_valid)
    dsets["target_valid"] = TensorDataset(noisy_x_valid, noisy_y_valid)

    dsets["source_test"] = TensorDataset(pristine_x_test, pristine_y_test)
    dsets["target_test"] = TensorDataset(noisy_x_test, noisy_y_test)

    # Data loaders (batch sizes follow the original implementation)
    dset_loaders["source"] = DataLoader(dsets["source"],
                                        batch_size=36,
                                        shuffle=True,
                                        num_workers=1)
    dset_loaders["target"] = DataLoader(dsets["target"],
                                        batch_size=36,
                                        shuffle=True,
                                        num_workers=1)

    # Validation/test batch sizes follow the test setup of the original implementation
    dset_loaders["source_valid"] = DataLoader(dsets["source_valid"],
                                              batch_size=4,
                                              shuffle=True,
                                              num_workers=1)
    dset_loaders["target_valid"] = DataLoader(dsets["target_valid"],
                                              batch_size=4,
                                              shuffle=True,
                                              num_workers=1)

    dset_loaders["source_test"] = DataLoader(dsets["source_test"],
                                             batch_size=4,
                                             shuffle=True,
                                             num_workers=1)
    dset_loaders["target_test"] = DataLoader(dsets["target_test"],
                                             batch_size=4,
                                             shuffle=True,
                                             num_workers=1)

    class_num = config["network"]["params"]["class_num"]

    # load checkpoint
    print('load model from {}'.format(config['ckpt_path']))

    ckpt = torch.load(config['ckpt_path'])
    print('recorded best training accuracy: {:0.4f} at step {}'.format(
        ckpt["train accuracy"], ckpt["step"]))
    print('recorded best validation accuracy: {:.4f} at step {}'.format(
        ckpt["valid accuracy"], ckpt["step"]))

    train_accuracy = ckpt["train accuracy"]
    valid_accuracy = ckpt["valid accuracy"]
    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    base_network.load_state_dict(ckpt['base_network'])

    centroids = None
    if 'center_criterion' in ckpt.keys():
        centroids = ckpt['center_criterion']['centers'].cpu()
    target_centroids = None
    if 'target_center_criterion' in ckpt.keys():
        target_centroids = ckpt['target_center_criterion']['centers'].cpu()

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## test
    print("start test: ")
    base_network.train(False)
    if config["ly_type"] == 'cosine':
        source_test_acc, source_test_confusion_matrix = image_classification_test(dset_loaders, "source_test", \
            base_network, \
            gpu=use_gpu)
        target_test_acc, target_test_confusion_matrix = image_classification_test(dset_loaders, "target_test", \
            base_network, \
            gpu=use_gpu)

    elif config["ly_type"] == "euclidean":
        eval_centroids = None
        if centroids is not None:
            eval_centroids = centroids
        if target_centroids is not None:
            eval_centroids = target_centroids
        source_test_acc, source_test_confusion_matrix = distance_classification_test(dset_loaders, "source_test", \
            base_network, eval_centroids, \
            gpu=use_gpu)
        target_test_acc, target_test_confusion_matrix = distance_classification_test(dset_loaders, "target_test", \
            base_network, eval_centroids, \
            gpu=use_gpu)

    # Save the train/validation/test accuracies to a pickle file for later analysis.
    with open(os.path.join(config["output_path"], 'accuracy.pkl'),
              'wb') as pkl_file:
        pkl.dump(
            {
                'train': train_accuracy,
                'valid': valid_accuracy,
                'source test': source_test_acc,
                'target test': target_test_acc
            }, pkl_file)

    np.set_printoptions(precision=2)
    log_str = "train accuracy: {:.5f}\tvalid accuracy: {:.5f}\nsource test accuracy: {:.5f}\nsource confusion matrix:\n{}\ntarget test accuracy: {:.5f}\ntarget confusion matrix:\n{}\n".format(
        train_accuracy, valid_accuracy, source_test_acc,
        source_test_confusion_matrix, target_test_acc,
        target_test_confusion_matrix)
    config["out_file"].write(log_str)
    config["out_file"].flush()
    print(log_str)

    return (source_test_acc, target_test_acc)
def train(config):
    ## set up summary writer
    writer = SummaryWriter(config['output_path'])
    class_num = config["network"]["params"]["class_num"]
    loss_params = config["loss"]

    class_criterion = nn.CrossEntropyLoss()
    transfer_criterion = loss.PADA
    center_criterion = loss_params["loss_type"](
        num_classes=class_num,
        feat_dim=config["network"]["params"]["bottleneck_dim"])
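    # PADA provides the adversarial transfer loss; center_criterion is assumed
    # to maintain per-class centroids in the bottleneck feature space.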

    ## prepare data
    dsets = {}
    dset_loaders = {}

    # Sample without replacement: 70/10/20 train/validation/test split.
    pristine_indices = torch.randperm(len(pristine_x))
    #train
    pristine_x_train = pristine_x[
        pristine_indices[:int(np.floor(.7 * len(pristine_x)))]]
    pristine_y_train = pristine_y[
        pristine_indices[:int(np.floor(.7 * len(pristine_x)))]]
    #validate --- gets passed into test functions in train file
    pristine_x_valid = pristine_x[pristine_indices[
        int(np.floor(.7 * len(pristine_x))):int(np.floor(.8 *
                                                         len(pristine_x)))]]
    pristine_y_valid = pristine_y[pristine_indices[
        int(np.floor(.7 * len(pristine_x))):int(np.floor(.8 *
                                                         len(pristine_x)))]]
    #test for evaluation file
    pristine_x_test = pristine_x[
        pristine_indices[int(np.floor(.8 * len(pristine_x))):]]
    pristine_y_test = pristine_y[
        pristine_indices[int(np.floor(.8 * len(pristine_x))):]]

    noisy_indices = torch.randperm(len(noisy_x))
    #train
    noisy_x_train = noisy_x[noisy_indices[:int(np.floor(.7 * len(noisy_x)))]]
    noisy_y_train = noisy_y[noisy_indices[:int(np.floor(.7 * len(noisy_x)))]]
    #validate --- gets passed into test functions in train file
    noisy_x_valid = noisy_x[noisy_indices[int(np.floor(.7 * len(noisy_x))
                                              ):int(np.floor(.8 *
                                                             len(noisy_x)))]]
    noisy_y_valid = noisy_y[noisy_indices[int(np.floor(.7 * len(noisy_x))
                                              ):int(np.floor(.8 *
                                                             len(noisy_x)))]]
    #test for evaluation file
    noisy_x_test = noisy_x[noisy_indices[int(np.floor(.8 * len(noisy_x))):]]
    noisy_y_test = noisy_y[noisy_indices[int(np.floor(.8 * len(noisy_x))):]]

    dsets["source"] = TensorDataset(pristine_x_train, pristine_y_train)
    dsets["target"] = TensorDataset(noisy_x_train, noisy_y_train)

    dsets["source_valid"] = TensorDataset(pristine_x_valid, pristine_y_valid)
    dsets["target_valid"] = TensorDataset(noisy_x_valid, noisy_y_valid)

    dsets["source_test"] = TensorDataset(pristine_x_test, pristine_y_test)
    dsets["target_test"] = TensorDataset(noisy_x_test, noisy_y_test)

    # Data loaders (batch sizes follow the original implementation)
    dset_loaders["source"] = DataLoader(dsets["source"],
                                        batch_size=128,
                                        shuffle=True,
                                        num_workers=1)
    dset_loaders["target"] = DataLoader(dsets["target"],
                                        batch_size=128,
                                        shuffle=True,
                                        num_workers=1)

    # Validation/test batch sizes follow the test setup of the original implementation
    dset_loaders["source_valid"] = DataLoader(dsets["source_valid"],
                                              batch_size=64,
                                              shuffle=True,
                                              num_workers=1)
    dset_loaders["target_valid"] = DataLoader(dsets["target_valid"],
                                              batch_size=64,
                                              shuffle=True,
                                              num_workers=1)

    dset_loaders["source_test"] = DataLoader(dsets["source_test"],
                                             batch_size=64,
                                             shuffle=True,
                                             num_workers=1)
    dset_loaders["target_test"] = DataLoader(dsets["target_test"],
                                             batch_size=64,
                                             shuffle=True,
                                             num_workers=1)

    config['out_file'].write("dataset sizes: source={}, target={}\n".format(
        len(dsets["source"]), len(dsets["target"])))

    config["num_iterations"] = len(
        dset_loaders["source"]) * config["epochs"] + 1
    config["test_interval"] = len(dset_loaders["source"])
    config["snapshot_interval"] = len(
        dset_loaders["source"]) * config["epochs"] * .25
    config["log_iter"] = len(dset_loaders["source"])
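    # Derived schedule: iterate for "epochs" passes over the source loader,
    # test and log once per epoch, and snapshot four times over the run.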

    #print the configuration you are using
    config["out_file"].write("config: {}\n".format(config))
    config["out_file"].flush()

    # set up early stop
    early_stop_engine = EarlyStopping(config["early_stop_patience"])

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## add additional network for some methods
    ad_net = network.AdversarialNetwork(base_network.output_num())
    gradient_reverse_layer = network.AdversarialLayer(
        high_value=config["high"])  #,
    #max_iter_value=config["num_iterations"])
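    # The gradient reversal layer flips gradient signs for adversarial domain
    # training; "high" is assumed to be the ceiling of its scheduled coefficient.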
    if use_gpu:
        ad_net = ad_net.cuda()

    ## collect parameters
    if "DeepMerge" in args.net:
        parameter_list = [{
            "params": base_network.parameters(),
            "lr_mult": 1,
            'decay_mult': 2
        }]
        parameter_list.append({
            "params": ad_net.parameters(),
            "lr_mult": .1,
            'decay_mult': 2
        })
        parameter_list.append({
            "params": center_criterion.parameters(),
            "lr_mult": 10,
            'decay_mult': 1
        })
    elif "ResNet18" in args.net:
        parameter_list = [{
            "params": base_network.parameters(),
            "lr_mult": 1,
            'decay_mult': 2
        }]
        parameter_list.append({
            "params": ad_net.parameters(),
            "lr_mult": .1,
            'decay_mult': 2
        })
        parameter_list.append({
            "params": center_criterion.parameters(),
            "lr_mult": 10,
            'decay_mult': 1
        })

    if net_config["params"]["new_cls"]:
        if net_config["params"]["use_bottleneck"]:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                            {"params":base_network.bottleneck.parameters(), "lr_mult":10, 'decay_mult':2}, \
                            {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
            parameter_list.append({
                "params": ad_net.parameters(),
                "lr_mult": config["ad_net_mult_lr"],
                'decay_mult': 2
            })
            parameter_list.append({
                "params": center_criterion.parameters(),
                "lr_mult": 10,
                'decay_mult': 1
            })
        else:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                            {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
            parameter_list.append({
                "params": ad_net.parameters(),
                "lr_mult": config["ad_net_mult_lr"],
                'decay_mult': 2
            })
            parameter_list.append({
                "params": center_criterion.parameters(),
                "lr_mult": 10,
                'decay_mult': 1
            })
    else:
        parameter_list = [{
            "params": base_network.parameters(),
            "lr_mult": 1,
            'decay_mult': 2
        }]
        parameter_list.append({
            "params": ad_net.parameters(),
            "lr_mult": config["ad_net_mult_lr"],
            'decay_mult': 2
        })
        parameter_list.append({
            "params": center_criterion.parameters(),
            "lr_mult": 10,
            'decay_mult': 1
        })
    # TODO: for DeepMerge, lr_mult should probably also be 1 here.

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    scan_lr = []
    scan_loss = []

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    len_valid_source = len(dset_loaders["source_valid"])
    len_valid_target = len(dset_loaders["target_valid"])

    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0

    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == 0:
            base_network.train(False)
            if config['loss']['ly_type'] == "cosine":
                temp_acc, _ = image_classification_test(dset_loaders, 'source_valid', \
                    base_network, gpu=use_gpu, verbose = False, save_where = None)
                train_acc, _ = image_classification_test(dset_loaders, 'source', \
                    base_network, gpu=use_gpu, verbose = False, save_where = None)
            elif config['loss']['ly_type'] == "euclidean":
                temp_acc, _ = distance_classification_test(dset_loaders, 'source_valid', \
                    base_network, center_criterion.centers.detach(), gpu=use_gpu, verbose = False, save_where = None)
                train_acc, _ = distance_classification_test(
                    dset_loaders,
                    'source',
                    base_network,
                    center_criterion.centers.detach(),
                    gpu=use_gpu,
                    verbose=False,
                    save_where=None)
            else:
                raise ValueError("no test method for cls loss: {}".format(
                    config['loss']['ly_type']))

            snapshot_obj = {
                'epoch': i / len(dset_loaders["source"]),
                "base_network": base_network.state_dict(),
                'valid accuracy': temp_acc,
                'train accuracy': train_acc,
            }

            if (i + 1) % config["snapshot_interval"] == 0:
                torch.save(
                    snapshot_obj,
                    osp.join(
                        config["output_path"], "epoch_{}_model.pth.tar".format(
                            i / len(dset_loaders["source"]))))

            if config["loss"]["loss_name"] != "laplacian" and config["loss"][
                    "ly_type"] == "euclidean":
                snapshot_obj['center_criterion'] = center_criterion.state_dict(
                )

            if temp_acc > best_acc:
                best_acc = temp_acc

                # save best model
                torch.save(
                    snapshot_obj,
                    osp.join(config["output_path"], "best_model.pth.tar"))

            log_str = "epoch: {}, {} validation accuracy: {:.5f}, {} training accuracy: {:.5f}\n".format(
                i / len(dset_loaders["source"]), config['loss']['ly_type'],
                temp_acc, config['loss']['ly_type'], train_acc)
            config["out_file"].write(log_str)
            config["out_file"].flush()
            writer.add_scalar("validation accuracy", temp_acc,
                              i / len(dset_loaders["source"]))
            writer.add_scalar("training accuracy", train_acc,
                              i / len(dset_loaders["source"]))

        ## train one iter
        base_network.train(True)

        if i % config["log_iter"] == 0:
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)

        if config["optimizer"]["lr_type"] == "one-cycle":
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)

        if config["optimizer"]["lr_type"] == "linear":
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)

        optim = optimizer.state_dict()
        scan_lr.append(optim['param_groups'][0]['lr'])

        optimizer.zero_grad()

        # Fresh iterators over the shuffled, non-empty loaders each yield one random batch.
        inputs_source, labels_source = next(iter(dset_loaders["source"]))
        inputs_target, labels_target = next(iter(dset_loaders["target"]))

        if use_gpu:
            inputs_source, inputs_target, labels_source = \
                Variable(inputs_source).cuda(), Variable(inputs_target).cuda(), \
                Variable(labels_source).cuda()
        else:
            inputs_source, inputs_target, labels_source = Variable(inputs_source), \
                Variable(inputs_target), Variable(labels_source)

        inputs = torch.cat((inputs_source, inputs_target), dim=0)
        source_batch_size = inputs_source.size(0)

        if config['loss']['ly_type'] == 'cosine':
            features, logits = base_network(inputs)
            source_logits = logits.narrow(0, 0, source_batch_size)
        elif config['loss']['ly_type'] == 'euclidean':
            features, _ = base_network(inputs)
            logits = -1.0 * loss.distance_to_centroids(
                features, center_criterion.centers.detach())
            source_logits = logits.narrow(0, 0, source_batch_size)

        ad_net.train(True)
        weight_ad = torch.ones(inputs.size(0))
        transfer_loss = transfer_criterion(features, ad_net, gradient_reverse_layer, \
                                            weight_ad, use_gpu)
        ad_out, _ = ad_net(features.detach())
        ad_acc, source_acc_ad, target_acc_ad = domain_cls_accuracy(ad_out)
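        # domain_cls_accuracy is assumed to return the overall, source, and
        # target domain-classification accuracies of ad_net on the joint batch.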

        # source domain classification task loss
        classifier_loss = class_criterion(source_logits, labels_source.long())
        # fisher loss on labeled source domain

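        # fisher_or_no == 'no': train with classifier + transfer loss only;
        # otherwise also add the Fisher discriminant and entropy terms below.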
        if config["fisher_or_no"] == 'no':
            total_loss = loss_params["trade_off"] * transfer_loss \
            + classifier_loss

            scan_loss.append(total_loss.cpu().float().item())

            total_loss.backward()

            ######################################
            # Plot embeddings periodically.
            if args.blobs is not None and i / len(
                    dset_loaders["source"]) % 50 == 0:
                visualizePerformance(base_network,
                                     dset_loaders["source"],
                                     dset_loaders["target"],
                                     batch_size=128,
                                     domain_classifier=ad_net,
                                     num_of_samples=100,
                                     imgName='embedding_' +
                                     str(i / len(dset_loaders["source"])),
                                     save_dir=osp.join(config["output_path"],
                                                       "blobs"))
            ##########################################

            # if center_grad is not None:
            #     # clear mmc_loss
            #     center_criterion.centers.grad.zero_()
            #     # Manually assign centers gradients other than using autograd
            #     center_criterion.centers.backward(center_grad)

            optimizer.step()

            if i % config["log_iter"] == 0:

                if config['lr_scan'] != 'no':
                    if not osp.exists(
                            osp.join(config["output_path"],
                                     "learning_rate_scan")):
                        os.makedirs(
                            osp.join(config["output_path"],
                                     "learning_rate_scan"))

                    plot_learning_rate_scan(
                        scan_lr, scan_loss, i / len(dset_loaders["source"]),
                        osp.join(config["output_path"], "learning_rate_scan"))

                if config['grad_vis'] != 'no':
                    if not osp.exists(
                            osp.join(config["output_path"], "gradients")):
                        os.makedirs(
                            osp.join(config["output_path"], "gradients"))

                    plot_grad_flow(
                        osp.join(config["output_path"], "gradients"),
                        i / len(dset_loaders["source"]),
                        base_network.named_parameters())

                config['out_file'].write(
                    'epoch {}: train total loss={:0.4f}, train transfer loss={:0.4f}, train classifier loss={:0.4f},'
                    'train source+target domain accuracy={:0.4f}, train source domain accuracy={:0.4f}, train target domain accuracy={:0.4f}\n'
                    .format(
                        i / len(dset_loaders["source"]),
                        total_loss.data.cpu().float().item(),
                        transfer_loss.data.cpu().float().item(),
                        classifier_loss.data.cpu().float().item(),
                        ad_acc,
                        source_acc_ad,
                        target_acc_ad,
                    ))
                config['out_file'].flush()
                writer.add_scalar("training total loss",
                                  total_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training classifier loss",
                                  classifier_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training transfer loss",
                                  transfer_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training source+target domain accuracy",
                                  ad_acc, i / len(dset_loaders["source"]))
                writer.add_scalar("training source domain accuracy",
                                  source_acc_ad,
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training target domain accuracy",
                                  target_acc_ad,
                                  i / len(dset_loaders["source"]))

                # validation step (no backprop)
                for j in range(0, len(dset_loaders["source_valid"])):
                    base_network.train(False)
                    with torch.no_grad():

                        # Fresh iterators over the validation loaders each yield one batch.
                        inputs_valid_source, labels_valid_source = next(
                            iter(dset_loaders["source_valid"]))
                        inputs_valid_target, labels_valid_target = next(
                            iter(dset_loaders["target_valid"]))

                        if use_gpu:
                            inputs_valid_source, inputs_valid_target, labels_valid_source = \
                                Variable(inputs_valid_source).cuda(), Variable(inputs_valid_target).cuda(), \
                                Variable(labels_valid_source).cuda()
                        else:
                            inputs_valid_source, inputs_valid_target, labels_valid_source = Variable(inputs_valid_source), \
                                Variable(inputs_valid_target), Variable(labels_valid_source)

                        valid_inputs = torch.cat(
                            (inputs_valid_source, inputs_valid_target), dim=0)
                        valid_source_batch_size = inputs_valid_source.size(0)

                        if config['loss']['ly_type'] == 'cosine':
                            features, logits = base_network(valid_inputs)
                            source_logits = logits.narrow(
                                0, 0, valid_source_batch_size)
                        elif config['loss']['ly_type'] == 'euclidean':
                            features, _ = base_network(valid_inputs)
                            logits = -1.0 * loss.distance_to_centroids(
                                features, center_criterion.centers.detach())
                            source_logits = logits.narrow(
                                0, 0, valid_source_batch_size)

                        ad_net.train(False)
                        weight_ad = torch.ones(valid_inputs.size(0))
                        transfer_loss = transfer_criterion(features, ad_net, gradient_reverse_layer, \
                                                           weight_ad, use_gpu)
                        ad_out, _ = ad_net(features.detach())
                        ad_acc, source_acc_ad, target_acc_ad = domain_cls_accuracy(
                            ad_out)

                        # source domain classification task loss
                        classifier_loss = class_criterion(
                            source_logits, labels_valid_source.long())

                        #if config["fisher_or_no"] == 'no':
                        total_loss = loss_params["trade_off"] * transfer_loss \
                                    + classifier_loss

                    if j % len(dset_loaders["source_valid"]) == 0:
                        config['out_file'].write(
                            'epoch {}: valid total loss={:0.4f}, valid transfer loss={:0.4f}, valid classifier loss={:0.4f},'
                            'valid source+target domain accuracy={:0.4f}, valid source domain accuracy={:0.4f}, valid target domain accuracy={:0.4f}\n'
                            .format(
                                i / len(dset_loaders["source"]),
                                total_loss.data.cpu().float().item(),
                                transfer_loss.data.cpu().float().item(),
                                classifier_loss.data.cpu().float().item(),
                                ad_acc,
                                source_acc_ad,
                                target_acc_ad,
                            ))
                        config['out_file'].flush()
                        writer.add_scalar("validation total loss",
                                          total_loss.data.cpu().float().item(),
                                          i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation classifier loss",
                            classifier_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation transfer loss",
                            transfer_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation source+target domain accuracy", ad_acc,
                            i / len(dset_loaders["source"]))
                        writer.add_scalar("validation source domain accuracy",
                                          source_acc_ad,
                                          i / len(dset_loaders["source"]))
                        writer.add_scalar("validation target domain accuracy",
                                          target_acc_ad,
                                          i / len(dset_loaders["source"]))

                        if early_stop_engine.is_stop_training(
                                classifier_loss.cpu().float().item()):
                            config["out_file"].write(
                                "no improvement after {}, stop training at step {}\n"
                                .format(config["early_stop_patience"],
                                        i / len(dset_loaders["source"])))

                            sys.exit()

        else:
            fisher_loss, fisher_intra_loss, fisher_inter_loss, center_grad = center_criterion(
                features.narrow(0, 0, int(inputs.size(0) / 2)),
                labels_source,
                inter_class=config["loss"]["inter_type"],
                intra_loss_weight=loss_params["intra_loss_coef"],
                inter_loss_weight=loss_params["inter_loss_coef"])
            # entropy minimization loss
            em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))

            # final loss
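            # Full objective: weighted transfer loss + Fisher discriminant loss
            # + weighted entropy minimization + source cross-entropy.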
            total_loss = loss_params["trade_off"] * transfer_loss \
                         + fisher_loss \
                         + loss_params["em_loss_coef"] * em_loss \
                         + classifier_loss

            scan_loss.append(total_loss.cpu().float().item())

            total_loss.backward()

            ######################################
            # Plot embeddings periodically.
            if args.blobs is not None and i / len(
                    dset_loaders["source"]) % 50 == 0:
                visualizePerformance(base_network,
                                     dset_loaders["source"],
                                     dset_loaders["target"],
                                     batch_size=128,
                                     num_of_samples=50,
                                     imgName='embedding_' +
                                     str(i / len(dset_loaders["source"])),
                                     save_dir=osp.join(config["output_path"],
                                                       "blobs"))
            ##########################################

            if center_grad is not None:
                # clear mmc_loss
                center_criterion.centers.grad.zero_()
                # Manually assign centers gradients other than using autograd
                center_criterion.centers.backward(center_grad)

            optimizer.step()

            if i % config["log_iter"] == 0:

                if config['lr_scan'] != 'no':
                    if not osp.exists(
                            osp.join(config["output_path"],
                                     "learning_rate_scan")):
                        os.makedirs(
                            osp.join(config["output_path"],
                                     "learning_rate_scan"))

                    plot_learning_rate_scan(
                        scan_lr, scan_loss, i / len(dset_loaders["source"]),
                        osp.join(config["output_path"], "learning_rate_scan"))

                if config['grad_vis'] != 'no':
                    if not osp.exists(
                            osp.join(config["output_path"], "gradients")):
                        os.makedirs(
                            osp.join(config["output_path"], "gradients"))

                    plot_grad_flow(
                        osp.join(config["output_path"], "gradients"),
                        i / len(dset_loaders["source"]),
                        base_network.named_parameters())

                config['out_file'].write(
                    'epoch {}: train total loss={:0.4f}, train transfer loss={:0.4f}, train classifier loss={:0.4f}, '
                    'train entropy min loss={:0.4f}, '
                    'train fisher loss={:0.4f}, train intra-group fisher loss={:0.4f}, train inter-group fisher loss={:0.4f}, '
                    'train source+target domain accuracy={:0.4f}, train source domain accuracy={:0.4f}, train target domain accuracy={:0.4f}\n'
                    .format(
                        i / len(dset_loaders["source"]),
                        total_loss.data.cpu().float().item(),
                        transfer_loss.data.cpu().float().item(),
                        classifier_loss.data.cpu().float().item(),
                        em_loss.data.cpu().float().item(),
                        fisher_loss.cpu().float().item(),
                        fisher_intra_loss.cpu().float().item(),
                        fisher_inter_loss.cpu().float().item(),
                        ad_acc,
                        source_acc_ad,
                        target_acc_ad,
                    ))

                config['out_file'].flush()
                writer.add_scalar("training total loss",
                                  total_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training classifier loss",
                                  classifier_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training transfer loss",
                                  transfer_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training entropy minimization loss",
                                  em_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training total fisher loss",
                                  fisher_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training intra-group fisher",
                                  fisher_intra_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training inter-group fisher",
                                  fisher_inter_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training source+target domain accuracy",
                                  ad_acc, i / len(dset_loaders["source"]))
                writer.add_scalar("training source domain accuracy",
                                  source_acc_ad,
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training target domain accuracy",
                                  target_acc_ad,
                                  i / len(dset_loaders["source"]))

                # validation step
                valid_source_iter = iter(dset_loaders["source_valid"])
                valid_target_iter = iter(dset_loaders["target_valid"])
                for j in range(len(dset_loaders["source_valid"])):
                    base_network.train(False)
                    with torch.no_grad():

                        try:
                            inputs_valid_source, labels_valid_source = next(
                                valid_source_iter)
                            inputs_valid_target, labels_valid_target = next(
                                valid_target_iter)
                        except StopIteration:
                            valid_source_iter = iter(
                                dset_loaders["source_valid"])
                            valid_target_iter = iter(
                                dset_loaders["target_valid"])
                            inputs_valid_source, labels_valid_source = next(
                                valid_source_iter)
                            inputs_valid_target, labels_valid_target = next(
                                valid_target_iter)

                        if use_gpu:
                            inputs_valid_source, inputs_valid_target, labels_valid_source = \
                                Variable(inputs_valid_source).cuda(), Variable(inputs_valid_target).cuda(), \
                                Variable(labels_valid_source).cuda()
                        else:
                            inputs_valid_source, inputs_valid_target, labels_valid_source = Variable(inputs_valid_source), \
                                Variable(inputs_valid_target), Variable(labels_valid_source)

                        valid_inputs = torch.cat(
                            (inputs_valid_source, inputs_valid_target), dim=0)
                        valid_source_batch_size = inputs_valid_source.size(0)

                        if config['loss']['ly_type'] == 'cosine':
                            features, logits = base_network(valid_inputs)
                            source_logits = logits.narrow(
                                0, 0, valid_source_batch_size)
                        elif config['loss']['ly_type'] == 'euclidean':
                            features, _ = base_network(valid_inputs)
                            logits = -1.0 * loss.distance_to_centroids(
                                features, center_criterion.centers.detach())
                            source_logits = logits.narrow(
                                0, 0, valid_source_batch_size)

                        ad_net.train(False)
                        weight_ad = torch.ones(valid_inputs.size(0))
                        transfer_loss = transfer_criterion(features, ad_net, gradient_reverse_layer, \
                                   weight_ad, use_gpu)
                        ad_out, _ = ad_net(features.detach())
                        ad_acc, source_acc_ad, target_acc_ad = domain_cls_accuracy(
                            ad_out)
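                        # domain_cls_accuracy presumably reports the domain discriminator's
                        # accuracy on the combined batch (ad_acc), on source samples
                        # (source_acc_ad), and on target samples (target_acc_ad).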

                        # source domain classification task loss
                        classifier_loss = class_criterion(
                            source_logits, labels_valid_source.long())

                        # fisher loss on labeled source domain
                        fisher_loss, fisher_intra_loss, fisher_inter_loss, center_grad = center_criterion(
                            features.narrow(0, 0,
                                            int(valid_inputs.size(0) / 2)),
                            labels_valid_source,
                            inter_class=loss_params["inter_type"],
                            intra_loss_weight=loss_params["intra_loss_coef"],
                            inter_loss_weight=loss_params["inter_loss_coef"])
                        # entropy minimization loss
                        em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))

                        # final loss
                        total_loss = loss_params["trade_off"] * transfer_loss \
                                     + fisher_loss \
                                     + loss_params["em_loss_coef"] * em_loss \
                                     + classifier_loss

                    if j == 0:  # log validation metrics once per pass
                        config['out_file'].write(
                            'epoch {}: valid total loss={:0.4f}, valid transfer loss={:0.4f}, valid classifier loss={:0.4f}, '
                            'valid entropy min loss={:0.4f}, '
                            'valid fisher loss={:0.4f}, valid intra-group fisher loss={:0.4f}, valid inter-group fisher loss={:0.4f}, '
                            'valid source+target domain accuracy={:0.4f}, valid source domain accuracy={:0.4f}, valid target domain accuracy={:0.4f}\n'
                            .format(
                                i / len(dset_loaders["source"]),
                                total_loss.data.cpu().float().item(),
                                transfer_loss.data.cpu().float().item(),
                                classifier_loss.data.cpu().float().item(),
                                em_loss.data.cpu().float().item(),
                                fisher_loss.cpu().float().item(),
                                fisher_intra_loss.cpu().float().item(),
                                fisher_inter_loss.cpu().float().item(),
                                ad_acc,
                                source_acc_ad,
                                target_acc_ad,
                            ))

                        config['out_file'].flush()
                        writer.add_scalar("validation total loss",
                                          total_loss.data.cpu().float().item(),
                                          i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation classifier loss",
                            classifier_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation entropy minimization loss",
                            em_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation transfer loss",
                            transfer_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation total fisher loss",
                            fisher_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation intra-group fisher",
                            fisher_intra_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation inter-group fisher",
                            fisher_inter_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation source+target domain accuracy", ad_acc,
                            i / len(dset_loaders["source"]))
                        writer.add_scalar("validation source domain accuracy",
                                          source_acc_ad,
                                          i / len(dset_loaders["source"]))
                        writer.add_scalar("validation target domain accuracy",
                                          target_acc_ad,
                                          i / len(dset_loaders["source"]))

                        if early_stop_engine.is_stop_training(
                                classifier_loss.cpu().float().item()):
                            config["out_file"].write(
                                "no improvement after {}, stop training at step {}\n"
                                .format(config["early_stop_patience"],
                                        i / len(dset_loaders["source"])))

                            sys.exit()

    return best_acc
Example #8
def train(config):
    ## set up summary writer
    writer = SummaryWriter(config['output_path'])

    # set up early stop
    early_stop_engine = EarlyStopping(config["early_stop_patience"])

    class_num = config["network"]["params"]["class_num"]

    class_criterion = nn.CrossEntropyLoss()

    loss_params = config["loss"]

    ## prepare data
    dsets = {}
    dset_loaders = {}

    # Sample indices without replacement; the middle 10% is held out for validation (70/10/20 split).
    pristine_indices = torch.randperm(len(pristine_x))
    #train
    pristine_x_train = pristine_x[
        pristine_indices[:int(np.floor(.7 * len(pristine_x)))]]
    pristine_y_train = pristine_y[
        pristine_indices[:int(np.floor(.7 * len(pristine_x)))]]
    #validate --- gets passed into test functions in train file
    pristine_x_valid = pristine_x[pristine_indices[
        int(np.floor(.7 * len(pristine_x))):int(np.floor(.8 *
                                                         len(pristine_x)))]]
    pristine_y_valid = pristine_y[pristine_indices[
        int(np.floor(.7 * len(pristine_x))):int(np.floor(.8 *
                                                         len(pristine_x)))]]
    #test for evaluation file
    pristine_x_test = pristine_x[
        pristine_indices[int(np.floor(.8 * len(pristine_x))):]]
    pristine_y_test = pristine_y[
        pristine_indices[int(np.floor(.8 * len(pristine_x))):]]

    dsets["source"] = TensorDataset(pristine_x_train, pristine_y_train)
    dsets["source_valid"] = TensorDataset(pristine_x_valid, pristine_y_valid)
    dsets["source_test"] = TensorDataset(pristine_x_test, pristine_y_test)

    # Data loaders; batch sizes follow those used in the other examples in this file.
    dset_loaders["source"] = DataLoader(dsets["source"],
                                        batch_size=36,
                                        shuffle=True,
                                        num_workers=1)
    dset_loaders["source_valid"] = DataLoader(dsets["source_valid"],
                                              batch_size=4,
                                              shuffle=True,
                                              num_workers=1)
    dset_loaders["source_test"] = DataLoader(dsets["source_test"],
                                             batch_size=4,
                                             shuffle=True,
                                             num_workers=1)

    config['out_file'].write("dataset sizes: source={}\n".format(
        len(dsets["source"])))

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## collect parameters
    if "DeepMerge" in args.net:
        parameter_list = [{
            "params": base_network.parameters(),
            "lr_mult": 1,
            'decay_mult': 2
        }]
    elif net_config["params"]["new_cls"]:
        if net_config["params"]["use_bottleneck"]:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                            {"params":base_network.bottleneck.parameters(), "lr_mult":10, 'decay_mult':2}, \
                            {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
        else:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                            {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
    else:
        parameter_list = [{
            "params": base_network.parameters(),
            "lr_mult": 1,
            'decay_mult': 2
        }]

    ## uniform class weights (used by some loss variants)
    class_weight = torch.from_numpy(np.array([1.0] * class_num))
    if use_gpu:
        class_weight = class_weight.cuda()

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]
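    # lr_scheduler is a schedule function looked up by name from lr_schedule.schedule_dict;
    # it is called once per iteration in the loop below, presumably rescaling each parameter
    # group's learning rate according to the lr_mult/decay_mult values set above.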

    ## train
    len_train_source = len(dset_loaders["source"]) - 1
    len_valid_source = len(dset_loaders["source_valid"]) - 1

    classifier_loss_value = 0.0
    best_acc = 0.0

    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == 0:
            base_network.train(False)
            if config['loss']['ly_type'] == "cosine":
                temp_acc, _ = image_classification_test(dset_loaders, 'source_valid', \
                    base_network, \
                    gpu=use_gpu)
                train_acc, _ = image_classification_test(dset_loaders, 'source', \
                    base_network, \
                    gpu=use_gpu)
            else:
                # The Euclidean distance loss cannot be used here because it involves the target domain.
                raise ValueError("no test method for cls loss: {}".format(
                    config['loss']['ly_type']))

            snapshot_obj = {
                'step': i,
                "base_network": base_network.state_dict(),
                'valid accuracy': temp_acc,
                'train accuracy': train_acc,
            }
            if temp_acc > best_acc:
                best_acc = temp_acc
                # save best model
                torch.save(
                    snapshot_obj,
                    osp.join(config["output_path"], "best_model.pth.tar"))
            log_str = "iter: {:05d}, {} validation accuracy: {:.5f}, {} training accuracy: {:.5f}\n".format(
                i, config['loss']['ly_type'], temp_acc,
                config['loss']['ly_type'], train_acc)
            config["out_file"].write(log_str)
            config["out_file"].flush()
            writer.add_scalar("validation accuracy", temp_acc, i)
            writer.add_scalar("training accuracy", train_acc, i)

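            # EarlyStopping presumably returns True from is_stop_training() once the monitored
            # metric (validation accuracy here) has not improved for
            # config["early_stop_patience"] successive checks.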
            if early_stop_engine.is_stop_training(temp_acc):
                config["out_file"].write(
                    "no improvement after {}, stop training at step {}\n".
                    format(config["early_stop_patience"], i))
                # config["out_file"].write("finish training! \n")
                break

        if (i + 1) % config["snapshot_interval"] == 0:
            torch.save(
                snapshot_obj,
                osp.join(config["output_path"],
                         "iter_{:05d}_model.pth.tar".format(i)))

        ## train one iter
        base_network.train(True)
        optimizer = lr_scheduler(param_lr, optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])

        try:
            inputs_source, labels_source = next(iter_source)
        except StopIteration:
            iter_source = iter(dset_loaders["source"])
            inputs_source, labels_source = next(iter_source)

        if use_gpu:
            inputs_source, labels_source = Variable(
                inputs_source).cuda(), Variable(labels_source).cuda()
        else:
            inputs_source, labels_source = Variable(inputs_source), Variable(
                labels_source)
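        # Note: torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4;
        # plain tensors moved with .cuda() (or .to(device)) would behave identically here.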

        inputs = inputs_source
        source_batch_size = inputs_source.size(0)

        features, logits = base_network(inputs)
        source_logits = logits.narrow(0, 0, source_batch_size)

        # source domain classification task loss
        classifier_loss = class_criterion(source_logits, labels_source.long())
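        # nn.CrossEntropyLoss expects raw (unnormalized) logits and integer class indices,
        # hence the .long() cast on the labels; no softmax is applied beforehand.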

        # entropy minimization loss
        #em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))
        #total_loss = loss_params["em_loss_coef"] * em_loss + classifier_loss

        total_loss = classifier_loss

        total_loss.backward()

        optimizer.step()

        if i % config["log_iter"] == 0:
            config['out_file'].write('iter {}: train total loss={:0.4f}, train classifier loss={:0.4f}\n'.format(i, \
                total_loss.data.cpu().float().item(), classifier_loss.data.cpu().float().item()))
            config['out_file'].flush()
            writer.add_scalar("training total loss",
                              total_loss.data.cpu().float().item(), i)
            writer.add_scalar("training classifier loss",
                              classifier_loss.data.cpu().float().item(), i)
            #writer.add_scalar("training entropy minimization loss", em_loss.data.cpu().float().item(), i)

        # validation step
        base_network.eval()
        with torch.no_grad():
            if i % len_valid_source == 0:
                iter_source = iter(dset_loaders["source_valid"])

            try:
                inputs_source, labels_source = next(iter_source)
            except StopIteration:
                iter_source = iter(dset_loaders["source_valid"])
                inputs_source, labels_source = next(iter_source)

            if use_gpu:
                inputs_source, labels_source = Variable(
                    inputs_source).cuda(), Variable(labels_source).cuda()
            else:
                inputs_source, labels_source = Variable(
                    inputs_source), Variable(labels_source)

            inputs = inputs_source
            source_batch_size = inputs_source.size(0)

            features, logits = base_network(inputs)
            source_logits = logits.narrow(0, 0, source_batch_size)

            # source domain classification task loss
            classifier_loss = class_criterion(source_logits,
                                              labels_source.long())

            # entropy minimization loss
            #em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))

            # final loss
            #total_loss = loss_params["em_loss_coef"] * em_loss + classifier_loss
            total_loss = classifier_loss
            # no total_loss.backward() here: validation runs in eval mode under torch.no_grad()

        if i % config["log_iter"] == 0:
            config['out_file'].write('iter {}: valid total loss={:0.4f}, valid classifier loss={:0.4f}\n'.format(i, \
                total_loss.data.cpu().float().item(), classifier_loss.data.cpu().float().item()))
            config['out_file'].flush()
            writer.add_scalar("validation total loss",
                              total_loss.data.cpu().float().item(), i)
            writer.add_scalar("validation classifier loss",
                              classifier_loss.data.cpu().float().item(), i)
            #writer.add_scalar("training entropy minimization loss", em_loss.data.cpu().float().item(), i)

    return best_acc
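
# A minimal, self-contained sketch of the 70/10/20 split + TensorDataset/DataLoader
# pattern used by the train() functions above. Tensor shapes and the helper name
# _split_demo are illustrative assumptions, not part of the original examples; the
# batch sizes (36 train, 4 valid, 4 test) mirror those used above.
def _split_demo():
    import numpy as np
    import torch
    from torch.utils.data import TensorDataset, DataLoader

    x = torch.randn(1000, 3, 75, 75)      # stand-in images
    y = torch.randint(0, 2, (1000,))      # stand-in binary labels

    idx = torch.randperm(len(x))
    n_train = int(np.floor(.7 * len(x)))
    n_valid = int(np.floor(.8 * len(x)))

    x_train, y_train = x[idx[:n_train]], y[idx[:n_train]]
    x_valid, y_valid = x[idx[n_train:n_valid]], y[idx[n_train:n_valid]]
    x_test, y_test = x[idx[n_valid:]], y[idx[n_valid:]]

    return {
        "source": DataLoader(TensorDataset(x_train, y_train), batch_size=36, shuffle=True),
        "source_valid": DataLoader(TensorDataset(x_valid, y_valid), batch_size=4, shuffle=True),
        "source_test": DataLoader(TensorDataset(x_test, y_test), batch_size=4, shuffle=True),
    }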