def train(config, data_import):
    """Train a domain-adaptation model on in-memory source/target data.

    Args:
        config: Nested settings dict. Keys read here include "network",
            "loss", "optimizer", "net", "epochs", "ckpt_path", "high",
            "fisher_or_no", "frozen lr" and "cycle_length";
            "ad_net_mult_lr" is read only for some "net" values. The keys
            "num_iterations", "test_interval", "snapshot_interval" and
            "log_iter" are overwritten in place, derived from the number of
            source batches.
        data_import: 4-tuple ``(pristine_x, pristine_y, noisy_x, noisy_y)``.
            The pristine pair is used as the labelled source domain and the
            noisy pair as the target domain (target labels are not used).
            Each array is indexed with a ``torch.randperm`` result, so they
            are presumably tensors -- TODO confirm with the caller.

    Returns:
        float: The negated total loss of the final training iteration.

    Raises:
        ValueError: If ``config["loss"]["ly_type"]`` is neither "cosine"
            nor "euclidean".
    """

    def _group(params, lr_mult, decay_mult):
        # One optimizer parameter group in the format used throughout.
        return {"params": params, "lr_mult": lr_mult, 'decay_mult': decay_mult}

    def _reschedule(current_optimizer, step):
        # Apply the configured learning-rate schedule for iteration `step`.
        return lr_scheduler(param_lr, current_optimizer, step,
                            config["log_iter"], config["frozen lr"],
                            config["cycle_length"], **schedule_param)

    def _next_batch(domain):
        # Pull the next batch; start a new pass once the loader is exhausted.
        try:
            return next(batch_iters[domain])
        except StopIteration:
            batch_iters[domain] = iter(dset_loaders[domain])
            return next(batch_iters[domain])

    net_config = config["network"]
    loss_params = config["loss"]
    class_num = net_config["params"]["class_num"]

    class_criterion = nn.CrossEntropyLoss()
    transfer_criterion = loss.PADA
    center_criterion = loss_params["discriminant_loss"](
        num_classes=class_num,
        feat_dim=net_config["params"]["bottleneck_dim"])

    ## prepare data
    pristine_x, pristine_y, noisy_x, noisy_y = data_import

    # Shuffle each domain once up front (sampling without replacement).
    pristine_indices = torch.randperm(len(pristine_x))
    noisy_indices = torch.randperm(len(noisy_x))

    dsets = {
        "source": TensorDataset(pristine_x[pristine_indices],
                                pristine_y[pristine_indices]),
        "target": TensorDataset(noisy_x[noisy_indices],
                                noisy_y[noisy_indices]),
    }
    dset_loaders = {
        "source": DataLoader(dsets["source"], batch_size=128, shuffle=True,
                             num_workers=1),
        "target": DataLoader(dsets["target"], batch_size=128, shuffle=True,
                             num_workers=1),
    }

    # All interval settings are expressed in source batches per epoch.
    batches_per_epoch = len(dset_loaders["source"])
    config["num_iterations"] = batches_per_epoch * config["epochs"] + 1
    config["test_interval"] = batches_per_epoch
    config["snapshot_interval"] = batches_per_epoch * config["epochs"] * .25
    config["log_iter"] = batches_per_epoch

    ## set base network
    base_network = net_config["name"](**net_config["params"])

    if config["ckpt_path"] is not None:
        ckpt = torch.load(config['ckpt_path'] + '/best_model.pth.tar',
                          map_location=torch.device('cpu'))
        base_network.load_state_dict(ckpt['base_network'])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## add additional network for some methods
    ad_net = network.AdversarialNetwork(base_network.output_num())
    gradient_reverse_layer = network.AdversarialLayer(
        high_value=config["high"])
    if use_gpu:
        ad_net = ad_net.cuda()

    ## collect parameters
    # Only the backbone groups and the adversarial-net multiplier depend on
    # the architecture; the center-criterion group is always the same.
    if "DeepMerge" in config["net"]:
        parameter_list = [_group(base_network.parameters(), 1, 2)]
        ad_lr_mult = .1
    elif "ResNet18" in config["net"]:
        if net_config["params"]["new_cls"]:
            parameter_list = [
                _group(base_network.feature_layers.parameters(), 1, 2)
            ]
            if net_config["params"]["use_bottleneck"]:
                parameter_list.append(
                    _group(base_network.bottleneck.parameters(), 10, 2))
            parameter_list.append(_group(base_network.fc.parameters(), 10, 2))
            ad_lr_mult = config["ad_net_mult_lr"]
        else:
            parameter_list = [_group(base_network.parameters(), 1, 2)]
            ad_lr_mult = .1
    else:
        parameter_list = [_group(base_network.parameters(), 1, 2)]
        ad_lr_mult = config["ad_net_mult_lr"]
    parameter_list.append(_group(ad_net.parameters(), ad_lr_mult, 2))
    parameter_list.append(_group(center_criterion.parameters(), 10, 1))

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](
        parameter_list, **(optimizer_config["optim_params"]))
    param_lr = [param_group["lr"] for param_group in optimizer.param_groups]
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    ## train
    # Keep one live iterator per loader. Building a fresh iterator on every
    # step would restart the loader each time and always return its first
    # batch.
    batch_iters = {
        "source": iter(dset_loaders["source"]),
        "target": iter(dset_loaders["target"]),
    }

    for i in range(config["num_iterations"]):

        ## train one iter
        base_network.train(True)

        # The schedule is applied at every epoch boundary, and additionally
        # on every iteration for the per-step schedules.
        if i % config["log_iter"] == 0:
            optimizer = _reschedule(optimizer, i)
        if config["optimizer"]["lr_type"] in ("one-cycle", "linear"):
            optimizer = _reschedule(optimizer, i)

        optimizer.zero_grad()

        inputs_source, labels_source = _next_batch("source")
        inputs_target, _ = _next_batch("target")

        # The last batch of a pass can be short, and the two domains need not
        # run out on the same step. The source rows are taken as the leading
        # slice of the concatenated batch, so trim both to a common size.
        # NOTE(review): the external PADA / domain_cls_accuracy helpers
        # presumably also assume an even source/target split -- confirm.
        common_size = min(inputs_source.size(0), inputs_target.size(0))
        inputs_source = inputs_source[:common_size]
        labels_source = labels_source[:common_size]
        inputs_target = inputs_target[:common_size]

        inputs_source = Variable(inputs_source)
        inputs_target = Variable(inputs_target)
        labels_source = Variable(labels_source)
        if use_gpu:
            inputs_source = inputs_source.cuda()
            inputs_target = inputs_target.cuda()
            labels_source = labels_source.cuda()

        inputs = torch.cat((inputs_source, inputs_target), dim=0)
        source_batch_size = inputs_source.size(0)

        if config['loss']['ly_type'] == 'cosine':
            features, logits = base_network(inputs)
        elif config['loss']['ly_type'] == 'euclidean':
            features, _ = base_network(inputs)
            # Negative distance to the class centers acts as the logit.
            logits = -1.0 * loss.distance_to_centroids(
                features, center_criterion.centers.detach())
        else:
            raise ValueError("unsupported ly_type: {}".format(
                config['loss']['ly_type']))
        source_logits = logits.narrow(0, 0, source_batch_size)

        ad_net.train(True)
        weight_ad = torch.ones(inputs.size(0))
        # The transfer loss is still evaluated, but it is not a term of the
        # total loss optimised below.
        transfer_loss = transfer_criterion(features, ad_net,
                                           gradient_reverse_layer, weight_ad,
                                           use_gpu)
        ad_out, _ = ad_net(features.detach())
        ad_acc, source_acc_ad, target_acc_ad = domain_cls_accuracy(ad_out)

        # source domain classification task loss
        classifier_loss = class_criterion(source_logits, labels_source.long())

        center_grad = None
        if config["fisher_or_no"] == 'no':
            # The classification loss is re-weighted by how far the domain
            # discriminator's per-domain accuracy is from chance (0.5).
            total_loss = classifier_loss + classifier_loss * (
                0.5 - source_acc_ad)**2 + classifier_loss * (
                    0.5 - target_acc_ad)**2
        else:
            # fisher loss on labeled source domain
            fisher_loss, fisher_intra_loss, fisher_inter_loss, center_grad = center_criterion(
                features.narrow(0, 0, source_batch_size),
                labels_source,
                inter_class=config["loss"]["inter_type"],
                intra_loss_weight=loss_params["intra_loss_coef"],
                inter_loss_weight=loss_params["inter_loss_coef"])
            # entropy minimization loss
            em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))

            # final loss
            total_loss = classifier_loss + fisher_loss + loss_params[
                "em_loss_coef"] * em_loss + classifier_loss * (
                    0.5 - source_acc_ad)**2 + classifier_loss * (
                        0.5 - target_acc_ad)**2

        total_loss.backward()

        if center_grad is not None:
            # Discard the autograd gradient on the centers and back-propagate
            # the gradient supplied by the criterion instead.
            center_criterion.centers.grad.zero_()
            center_criterion.centers.backward(center_grad)

        optimizer.step()

    return (-1 * total_loss.cpu().float().item())
예제 #2
0
def train(config):
    """Train with transfer, discriminant, entropy and classification losses.

    Builds the image datasets listed in ``config["data"]``, the base network
    and an adversarial network, then runs ``config["num_iterations"]``
    optimisation steps. Every ``config["test_interval"]`` steps the network
    is evaluated; the best-scoring snapshot is written to
    ``<output_path>/best_model.pth.tar`` and early stopping is consulted.

    Args:
        config: Nested settings dict. Keys read here include "output_path",
            "early_stop_patience", "prep", "network", "loss", "data",
            "optimizer", "high", "num_iterations", "test_interval",
            "snapshot_interval", "log_iter" and "out_file" (an object with
            ``write``/``flush``).

    Returns:
        The best evaluation accuracy observed (0.0 if none exceeded it).

    Raises:
        ValueError: If ``config["loss"]["ly_type"]`` is neither "cosine"
            nor "euclidean".
    """

    def _read_lines(list_path):
        # Read an image-list file and close the handle right away.
        with open(list_path) as list_file:
            return list_file.readlines()

    ## set up summary writer
    writer = SummaryWriter(config['output_path'])

    # set up early stop
    early_stop_engine = EarlyStopping(config["early_stop_patience"])

    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(
        resize_size=prep_config["resize_size"],
        crop_size=prep_config["crop_size"])
    prep_dict["target"] = prep.image_train(
        resize_size=prep_config["resize_size"],
        crop_size=prep_config["crop_size"])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(
            resize_size=prep_config["resize_size"],
            crop_size=prep_config["crop_size"])
    else:
        prep_dict["test"] = prep.image_test(
            resize_size=prep_config["resize_size"],
            crop_size=prep_config["crop_size"])

    ## set loss
    class_num = config["network"]["params"]["class_num"]
    loss_params = config["loss"]

    class_criterion = nn.CrossEntropyLoss()
    transfer_criterion = loss.PADA
    center_criterion = loss_params["loss_type"](
        num_classes=class_num,
        feat_dim=config["network"]["params"]["bottleneck_dim"])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    source_list_path = data_config["source"]["list_path"]
    target_list_path = data_config["target"]["list_path"]
    test_list_path = data_config["test"]["list_path"]
    test_batch_size = data_config["test"]["batch_size"]

    dsets["source"] = ImageList(
        stratify_sampling(_read_lines(source_list_path),
                          prep_config["source_size"]),
        transform=prep_dict["source"])
    dset_loaders["source"] = util_data.DataLoader(
        dsets["source"],
        batch_size=data_config["source"]["batch_size"],
        shuffle=True, num_workers=1)
    dsets["target"] = ImageList(
        stratify_sampling(_read_lines(target_list_path),
                          prep_config["target_size"]),
        transform=prep_dict["target"])
    dset_loaders["target"] = util_data.DataLoader(
        dsets["target"],
        batch_size=data_config["target"]["batch_size"],
        shuffle=True, num_workers=1)

    if prep_config["test_10crop"]:
        # One evaluation loader per crop, for both the test and target lists.
        for crop in range(10):
            test_key = "test" + str(crop)
            target_key = "target" + str(crop)
            crop_transform = prep_dict["test"]["val" + str(crop)]

            dsets[test_key] = ImageList(
                stratify_sampling(_read_lines(test_list_path),
                                  ratio=prep_config['target_size']),
                transform=crop_transform)
            dset_loaders[test_key] = util_data.DataLoader(
                dsets[test_key], batch_size=test_batch_size,
                shuffle=False, num_workers=1)

            dsets[target_key] = ImageList(
                stratify_sampling(_read_lines(target_list_path),
                                  ratio=prep_config['target_size']),
                transform=crop_transform)
            dset_loaders[target_key] = util_data.DataLoader(
                dsets[target_key], batch_size=test_batch_size,
                shuffle=False, num_workers=1)
    else:
        dsets["test"] = ImageList(
            stratify_sampling(_read_lines(test_list_path),
                              ratio=prep_config['target_size']),
            transform=prep_dict["test"])
        dset_loaders["test"] = util_data.DataLoader(
            dsets["test"], batch_size=test_batch_size,
            shuffle=False, num_workers=1)

        dsets["target_test"] = ImageList(
            stratify_sampling(_read_lines(target_list_path),
                              ratio=prep_config['target_size']),
            transform=prep_dict["test"])
        dset_loaders["target_test"] = MyDataLoader(
            dsets["target_test"], batch_size=test_batch_size,
            shuffle=False, num_workers=1)

    config['out_file'].write("dataset sizes: source={}, target={}\n".format(
        len(dsets["source"]), len(dsets["target"])))

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## collect parameters
    if net_config["params"]["new_cls"]:
        parameter_list = [{"params": base_network.feature_layers.parameters(),
                           "lr_mult": 1, 'decay_mult': 2}]
        if net_config["params"]["use_bottleneck"]:
            parameter_list.append(
                {"params": base_network.bottleneck.parameters(),
                 "lr_mult": 10, 'decay_mult': 2})
        parameter_list.append({"params": base_network.fc.parameters(),
                               "lr_mult": 10, 'decay_mult': 2})
    else:
        parameter_list = [{"params": base_network.parameters(),
                           "lr_mult": 1, 'decay_mult': 2}]

    ## add additional network for some methods
    ad_net = network.AdversarialNetwork(base_network.output_num())
    gradient_reverse_layer = network.AdversarialLayer(
        high_value=config["high"])
    if use_gpu:
        ad_net = ad_net.cuda()
    parameter_list.append({"params": ad_net.parameters(),
                           "lr_mult": 10, 'decay_mult': 2})
    parameter_list.append({"params": center_criterion.parameters(),
                           "lr_mult": 10, 'decay_mult': 1})

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](
        parameter_list, **(optimizer_config["optim_params"]))
    param_lr = [param_group["lr"] for param_group in optimizer.param_groups]
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    ## train
    # Each loader is restarted one batch before it runs out, so its final
    # batch is never drawn (presumably because it may be short -- TODO
    # confirm). Clamp to 1 so a single-batch loader does not cause a modulo
    # by zero.
    len_train_source = max(len(dset_loaders["source"]) - 1, 1)
    len_train_target = max(len(dset_loaders["target"]) - 1, 1)
    best_acc = 0.0
    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == 0:
            base_network.train(False)
            if config['loss']['ly_type'] == "cosine":
                temp_acc = image_classification_test(
                    dset_loaders, base_network,
                    test_10crop=prep_config["test_10crop"], gpu=use_gpu)
            elif config['loss']['ly_type'] == "euclidean":
                temp_acc, _ = distance_classification_test(
                    dset_loaders, base_network,
                    center_criterion.centers.detach(),
                    test_10crop=prep_config["test_10crop"], gpu=use_gpu)
            else:
                raise ValueError("no test method for cls loss: {}".format(config['loss']['ly_type']))

            snapshot_obj = {'step': i,
                            "base_network": base_network.state_dict(),
                            'precision': temp_acc,
                            }
            if config["loss"]["loss_name"] != "laplacian" and config["loss"]["ly_type"] == "euclidean":
                snapshot_obj['center_criterion'] = center_criterion.state_dict()
            if temp_acc > best_acc:
                best_acc = temp_acc
                # save best model
                torch.save(snapshot_obj,
                           osp.join(config["output_path"], "best_model.pth.tar"))
            log_str = "iter: {:05d}, {} precision: {:.5f}\n".format(i, config['loss']['ly_type'], temp_acc)
            config["out_file"].write(log_str)
            config["out_file"].flush()
            writer.add_scalar("precision", temp_acc, i)

            if early_stop_engine.is_stop_training(temp_acc):
                config["out_file"].write("no improvement after {}, stop training at step {}\n".format(
                    config["early_stop_patience"], i))
                break

        if (i+1) % config["snapshot_interval"] == 0:
            # snapshot_obj is the dict assembled at the most recent
            # evaluation step above.
            torch.save(snapshot_obj,
                       osp.join(config["output_path"], "iter_{:05d}_model.pth.tar".format(i)))

        ## train one iter
        base_network.train(True)
        optimizer = lr_scheduler(param_lr, optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        # Built-in next() works on every Python 3 / PyTorch version; the
        # iterator's .next() method does not.
        inputs_source, labels_source = next(iter_source)
        inputs_target, _ = next(iter_target)

        inputs_source = Variable(inputs_source)
        inputs_target = Variable(inputs_target)
        labels_source = Variable(labels_source)
        if use_gpu:
            inputs_source = inputs_source.cuda()
            inputs_target = inputs_target.cuda()
            labels_source = labels_source.cuda()

        inputs = torch.cat((inputs_source, inputs_target), dim=0)
        source_batch_size = inputs_source.size(0)

        if config['loss']['ly_type'] == 'cosine':
            features, logits = base_network(inputs)
            source_logits = logits.narrow(0, 0, source_batch_size)
        elif config['loss']['ly_type'] == 'euclidean':
            features, _ = base_network(inputs)
            # Negative distance to the class centers acts as the logit.
            logits = -1.0 * loss.distance_to_centroids(features, center_criterion.centers.detach())
            source_logits = logits.narrow(0, 0, source_batch_size)

        ad_net.train(True)
        weight_ad = torch.ones(inputs.size(0))
        transfer_loss = transfer_criterion(features, ad_net,
                                           gradient_reverse_layer,
                                           weight_ad, use_gpu)
        ad_out, _ = ad_net(features.detach())
        ad_acc, source_acc_ad, target_acc_ad = domain_cls_accuracy(ad_out)

        # source domain classification task loss
        classifier_loss = class_criterion(source_logits, labels_source)
        # fisher loss on labeled source domain; slice by the real source
        # batch size because source and target batch sizes are configured
        # independently.
        fisher_loss, fisher_intra_loss, fisher_inter_loss, center_grad = center_criterion(
            features.narrow(0, 0, source_batch_size),
            labels_source,
            inter_class=config["loss"]["inter_type"],
            intra_loss_weight=loss_params["intra_loss_coef"],
            inter_loss_weight=loss_params["inter_loss_coef"])
        # entropy minimization loss
        em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))

        # final loss
        total_loss = loss_params["trade_off"] * transfer_loss \
                     + fisher_loss \
                     + loss_params["em_loss_coef"] * em_loss \
                     + classifier_loss

        total_loss.backward()
        if center_grad is not None:
            # Discard the autograd gradient on the centers and back-propagate
            # the gradient supplied by the criterion instead.
            center_criterion.centers.grad.zero_()
            center_criterion.centers.backward(center_grad)

        optimizer.step()

        if i % config["log_iter"] == 0:
            # Convert each loss to a Python float once and reuse it for both
            # the text log and the summary writer.
            total_loss_value = total_loss.data.cpu().float().item()
            transfer_loss_value = transfer_loss.data.cpu().float().item()
            classifier_loss_value = classifier_loss.data.cpu().float().item()
            em_loss_value = em_loss.data.cpu().float().item()
            fisher_loss_value = fisher_loss.data.cpu().float().item()
            fisher_intra_value = fisher_intra_loss.data.cpu().float().item()
            fisher_inter_value = fisher_inter_loss.data.cpu().float().item()

            config['out_file'].write(
                'iter {}: total loss={:0.4f}, transfer loss={:0.4f}, cls loss={:0.4f}, '
                'em loss={:0.4f}, '
                'mmc loss={:0.4f}, intra loss={:0.4f}, inter loss={:0.4f}, '
                'ad acc={:0.4f}, source_acc={:0.4f}, target_acc={:0.4f}\n'.format(
                    i, total_loss_value, transfer_loss_value,
                    classifier_loss_value, em_loss_value,
                    fisher_loss_value, fisher_intra_value, fisher_inter_value,
                    ad_acc, source_acc_ad, target_acc_ad))
            config['out_file'].flush()

            writer.add_scalar("total_loss", total_loss_value, i)
            writer.add_scalar("cls_loss", classifier_loss_value, i)
            writer.add_scalar("transfer_loss", transfer_loss_value, i)
            writer.add_scalar("ad_acc", ad_acc, i)
            writer.add_scalar("d_loss/total", fisher_loss_value, i)
            writer.add_scalar("d_loss/intra", fisher_intra_value, i)
            writer.add_scalar("d_loss/inter", fisher_inter_value, i)

    # Flush pending summaries and release the event file.
    writer.close()
    return best_acc
def train(config):
    """Train the base network with a class-weighted source classification loss.

    Every ``config["test_interval"]`` steps the network is evaluated and, when
    the accuracy improves, written to ``<output_path>/best_model.pth.tar``.
    Every ``loss_params["update_iter"]`` steps the class weights of the
    cross-entropy loss are re-estimated from the network's mean prediction
    over the target data.

    Args:
        config: Nested settings dict. Keys read here include "prep", "loss",
            "data", "network", "optimizer", "high", "softmax_param",
            "num_iterations", "test_interval", "snapshot_interval",
            "output_path" and "out_file" (an object with ``write``/``flush``).

    Returns:
        The best evaluation accuracy observed (0.0 if none exceeded it).
    """

    def _read_lines(list_path):
        # Read an image-list file and close the handle right away.
        with open(list_path) as list_file:
            return list_file.readlines()

    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(
        resize_size=prep_config["resize_size"],
        crop_size=prep_config["crop_size"])
    prep_dict["target"] = prep.image_train(
        resize_size=prep_config["resize_size"],
        crop_size=prep_config["crop_size"])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(
            resize_size=prep_config["resize_size"],
            crop_size=prep_config["crop_size"])
    else:
        prep_dict["test"] = prep.image_test(
            resize_size=prep_config["resize_size"],
            crop_size=prep_config["crop_size"])

    ## set loss
    class_criterion = nn.CrossEntropyLoss()  # classification loss
    # The two criteria below are created but not used by the objective in
    # this function.
    transfer_criterion1 = loss.PADA
    balance_criterion = loss.balance_loss()
    loss_params = config["loss"]

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    source_list_path = data_config["source"]["list_path"]
    target_list_path = data_config["target"]["list_path"]
    test_list_path = data_config["test"]["list_path"]
    test_batch_size = data_config["test"]["batch_size"]

    dsets["source"] = ImageList(_read_lines(source_list_path),
                                transform=prep_dict["source"])
    dset_loaders["source"] = util_data.DataLoader(
        dsets["source"],
        batch_size=data_config["source"]["batch_size"],
        shuffle=True, num_workers=4)
    dsets["target"] = ImageList(_read_lines(target_list_path),
                                transform=prep_dict["target"])
    dset_loaders["target"] = util_data.DataLoader(
        dsets["target"],
        batch_size=data_config["target"]["batch_size"],
        shuffle=True, num_workers=4)

    if prep_config["test_10crop"]:
        # One evaluation loader per crop, for both the test and target lists.
        for crop in range(10):
            test_key = "test" + str(crop)
            target_key = "target" + str(crop)
            crop_transform = prep_dict["test"]["val" + str(crop)]

            dsets[test_key] = ImageList(_read_lines(test_list_path),
                                        transform=crop_transform)
            dset_loaders[test_key] = util_data.DataLoader(
                dsets[test_key], batch_size=test_batch_size,
                shuffle=False, num_workers=4)

            dsets[target_key] = ImageList(_read_lines(target_list_path),
                                          transform=crop_transform)
            dset_loaders[target_key] = util_data.DataLoader(
                dsets[target_key], batch_size=test_batch_size,
                shuffle=False, num_workers=4)
    else:
        dsets["test"] = ImageList(_read_lines(test_list_path),
                                  transform=prep_dict["test"])
        dset_loaders["test"] = util_data.DataLoader(
            dsets["test"], batch_size=test_batch_size,
            shuffle=False, num_workers=4)

        dsets["target_test"] = ImageList(_read_lines(target_list_path),
                                         transform=prep_dict["test"])
        dset_loaders["target_test"] = MyDataLoader(
            dsets["target_test"], batch_size=test_batch_size,
            shuffle=False, num_workers=4)

    class_num = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## collect parameters
    if net_config["params"]["new_cls"]:
        parameter_list = [{"params": base_network.feature_layers.parameters(),
                           "lr": 1}]
        if net_config["params"]["use_bottleneck"]:
            parameter_list.append(
                {"params": base_network.bottleneck.parameters(), "lr": 10})
        parameter_list.append({"params": base_network.fc.parameters(),
                               "lr": 10})
    else:
        parameter_list = [{"params": base_network.parameters(), "lr": 1}]

    ## add additional network for some methods
    # Class weights start as a class_num-long vector of ones.
    class_weight = torch.from_numpy(np.array([1.0] * class_num))
    if use_gpu:
        class_weight = class_weight.cuda()
    # Domain discriminator fed with the base network's output features.
    ad_net = network.AdversarialNetwork(base_network.output_num())
    nad_net = network.NAdversarialNetwork(base_network.output_num())
    gradient_reverse_layer = network.AdversarialLayer(
        high_value=config["high"])
    if use_gpu:
        ad_net = ad_net.cuda()
        nad_net = nad_net.cuda()
    parameter_list.append({"params": ad_net.parameters(), "lr": 10})

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](
        parameter_list, **(optimizer_config["optim_params"]))
    param_lr = [param_group["lr"] for param_group in optimizer.param_groups]
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    ## train
    # Each loader is restarted one batch before it runs out, so its final
    # batch is never drawn (presumably because it may be short -- TODO
    # confirm). Clamp to 1 so a single-batch loader does not cause a modulo
    # by zero.
    len_train_source = max(len(dset_loaders["source"]) - 1, 1)
    len_train_target = max(len(dset_loaders["target"]) - 1, 1)
    best_acc = 0.0
    best_model_path = osp.join(config["output_path"], "best_model.pth.tar")
    best_model_saved = False
    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == 0:  # periodic evaluation
            base_network.train(False)  # switch to evaluation mode
            temp_acc = image_classification_test(
                dset_loaders, base_network,
                test_10crop=prep_config["test_10crop"], gpu=use_gpu)
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                # temp_model wraps the live network, so keeping a reference
                # to it would track later weight updates. Persist the best
                # weights to disk immediately instead.
                torch.save(temp_model, best_model_path)
                best_model_saved = True
            log_str = "iter: {:05d}, alpha: {:03f}, tradeoff: {:03f} ,precision: {:.5f}".format(
                i, loss_params["para_alpha"], loss_params["trade_off"],
                temp_acc)
            config["out_file"].write(log_str + '\n')
            config["out_file"].flush()
            print(log_str)
        if i % config["snapshot_interval"] == 0:  # periodic snapshot
            torch.save(nn.Sequential(base_network),
                       osp.join(config["output_path"],
                                "iter_{:05d}_model.pth.tar".format(i)))

        if i % loss_params["update_iter"] == loss_params["update_iter"] - 1:
            # Re-estimate the class weights from the current predictions on
            # the target data.
            base_network.train(False)
            target_fc8_out = image_classification_predict(
                dset_loaders,
                base_network,
                softmax_param=config["softmax_param"])
            # Average the per-sample prediction vectors into one weight per
            # class, then normalise so the weights have mean 1.
            class_weight = torch.mean(target_fc8_out, 0)
            class_weight = (class_weight / torch.mean(class_weight)).view(-1)
            if use_gpu:
                class_weight = class_weight.cuda()
            # From here on the cross-entropy loss is class-weighted.
            class_criterion = nn.CrossEntropyLoss(weight=class_weight)

        ## train one iter
        base_network.train(True)
        optimizer = lr_scheduler(param_lr, optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        # Built-in next() works on every Python 3 / PyTorch version; the
        # iterator's .next() method does not.
        inputs_source, labels_source = next(iter_source)
        inputs_target, _ = next(iter_target)

        inputs_source = Variable(inputs_source)
        inputs_target = Variable(inputs_target)
        labels_source = Variable(labels_source)
        if use_gpu:
            inputs_source = inputs_source.cuda()
            inputs_target = inputs_target.cuda()
            labels_source = labels_source.cuda()

        # Source rows come first in the joint batch, target rows after them.
        inputs = torch.cat((inputs_source, inputs_target), dim=0)
        # Use the real source batch size as an int; "size / 2" is a float on
        # Python 3 and is rejected by range(), slicing and narrow().
        source_batch_size = inputs_source.size(0)

        features, outputs = base_network(inputs)
        softmax_out = nn.Softmax(dim=1)(outputs).detach()
        ad_net.train(True)
        nad_net.train(True)

        # Per-sample weights: each source sample gets the weight of its
        # class, scaled by the largest source weight; target samples get 1.
        # These weights are not consumed by the loss below.
        weight_ad = torch.zeros(inputs.size(0))
        label_numpy = labels_source.data.cpu().numpy()
        for j in range(source_batch_size):
            weight_ad[j] = class_weight[int(label_numpy[j])]
        weight_ad = weight_ad / torch.max(weight_ad[0:source_batch_size])
        weight_ad[source_batch_size:] = 1.0

        # Classification loss on the source part of the batch only.
        classifier_loss = class_criterion(
            outputs.narrow(0, 0, source_batch_size), labels_source)
        total_loss = classifier_loss
        total_loss.backward()
        optimizer.step()

    if not best_model_saved:
        # No evaluation ever beat the initial accuracy; still write the file
        # so that callers find a checkpoint at the expected path.
        torch.save(nn.Sequential(base_network), best_model_path)
    return best_acc
예제 #4
0
def train(config):
    """Train a network with an adversarial transfer loss and report CV loss.

    The source image list is split into a training part and a validation part
    with ``sep.split_set``. Data loaders, the base network and the adversarial
    network are built from ``config``; then ``config["num_iterations"]``
    optimisation steps are run on
    ``loss["trade_off"] * transfer_loss + classifier_loss``.

    Every ``config["test_interval"]`` iterations the test accuracy is logged
    to ``config["out_file"]`` and printed; every ``config["snapshot_interval"]``
    iterations the network is saved under ``config["output_path"]``. After the
    loop, ``dev.cross_validation_loss`` is evaluated and its result printed.

    Args:
        config: Nested dict. Keys read here:
            ``prep`` (``resize_size``, ``crop_size``, ``test_10crop``),
            ``loss`` (``trade_off``),
            ``data`` (``source``/``target``: ``list_path``, ``batch_size``;
            ``test``: ``list_path``, ``batch_size``),
            ``network`` (``name`` callable, ``params`` with ``class_num``,
            ``new_cls``, ``use_bottleneck``),
            ``optimizer`` (``type``, ``optim_params``, ``lr_type``,
            ``lr_param``), ``high``, ``num_iterations``, ``test_interval``,
            ``snapshot_interval``, ``out_file`` (writable file object) and
            ``output_path``.

    Returns:
        None. The cross-validation loss is printed, not returned.
    """

    def _read_lines(list_path):
        # Read an image-list file and close the handle deterministically.
        with open(list_path) as list_file:
            return list_file.readlines()

    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(
        resize_size=prep_config["resize_size"],
        crop_size=prep_config["crop_size"])
    prep_dict["target"] = prep.image_train(
        resize_size=prep_config["resize_size"],
        crop_size=prep_config["crop_size"])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(
            resize_size=prep_config["resize_size"],
            crop_size=prep_config["crop_size"])
    else:
        prep_dict["test"] = prep.image_test(
            resize_size=prep_config["resize_size"],
            crop_size=prep_config["crop_size"])

    ## set loss
    class_criterion = nn.CrossEntropyLoss()
    transfer_criterion = loss.PADA
    loss_params = config["loss"]

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    class_num = config["network"]["params"]["class_num"]

    # separate the source list into a training part and a validation part
    cls_source_list, cls_validation_list = sep.split_set(
        data_config["source"]["list_path"], class_num)
    source_list = sep.dimension_rd(cls_source_list)

    dsets["source"] = ImageList(source_list,
                                transform=prep_dict["source"])
    dset_loaders["source"] = util_data.DataLoader(
        dsets["source"],
        batch_size=data_config["source"]["batch_size"],
        shuffle=True, num_workers=4)
    dsets["target"] = ImageList(
        _read_lines(data_config["target"]["list_path"]),
        transform=prep_dict["target"])
    dset_loaders["target"] = util_data.DataLoader(
        dsets["target"],
        batch_size=data_config["target"]["batch_size"],
        shuffle=True, num_workers=4)

    if prep_config["test_10crop"]:
        # One test loader and one target loader per crop transform.
        for i in range(10):
            dsets["test" + str(i)] = ImageList(
                _read_lines(data_config["test"]["list_path"]),
                transform=prep_dict["test"]["val" + str(i)])
            dset_loaders["test" + str(i)] = util_data.DataLoader(
                dsets["test" + str(i)],
                batch_size=data_config["test"]["batch_size"],
                shuffle=False, num_workers=4)

            dsets["target" + str(i)] = ImageList(
                _read_lines(data_config["target"]["list_path"]),
                transform=prep_dict["test"]["val" + str(i)])
            dset_loaders["target" + str(i)] = util_data.DataLoader(
                dsets["target" + str(i)],
                batch_size=data_config["test"]["batch_size"],
                shuffle=False, num_workers=4)
    else:
        dsets["test"] = ImageList(
            _read_lines(data_config["test"]["list_path"]),
            transform=prep_dict["test"])
        dset_loaders["test"] = util_data.DataLoader(
            dsets["test"],
            batch_size=data_config["test"]["batch_size"],
            shuffle=False, num_workers=4)

        dsets["target_test"] = ImageList(
            _read_lines(data_config["target"]["list_path"]),
            transform=prep_dict["test"])
        dset_loaders["target_test"] = MyDataLoader(
            dsets["target_test"],
            batch_size=data_config["test"]["batch_size"],
            shuffle=False, num_workers=4)

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## collect parameters
    # The per-group "lr" values are captured into param_lr below and handed to
    # the scheduler on every iteration.
    if net_config["params"]["new_cls"]:
        if net_config["params"]["use_bottleneck"]:
            parameter_list = [
                {"params": base_network.feature_layers.parameters(), "lr": 1},
                {"params": base_network.bottleneck.parameters(), "lr": 10},
                {"params": base_network.fc.parameters(), "lr": 10}]
        else:
            parameter_list = [
                {"params": base_network.feature_layers.parameters(), "lr": 1},
                {"params": base_network.fc.parameters(), "lr": 10}]
    else:
        parameter_list = [{"params": base_network.parameters(), "lr": 1}]

    ## add additional network for some methods
    ad_net = network.AdversarialNetwork(base_network.output_num())
    gradient_reverse_layer = network.AdversarialLayer(
        high_value=config["high"])
    if use_gpu:
        ad_net = ad_net.cuda()
    parameter_list.append({"params": ad_net.parameters(), "lr": 10})

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](
        parameter_list, **(optimizer_config["optim_params"]))
    param_lr = [param_group["lr"] for param_group in optimizer.param_groups]
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    ## train
    # The iterators are rebuilt one batch before the loader is exhausted, so
    # the last batch of each pass is never consumed. Clamp to 1 so a loader
    # with a single batch does not cause a modulo-by-zero below.
    len_train_source = max(len(dset_loaders["source"]) - 1, 1)
    len_train_target = max(len(dset_loaders["target"]) - 1, 1)
    for i in range(config["num_iterations"]):
        if (i + 1) % config["test_interval"] == 0:
            base_network.train(False)
            temp_acc = image_classification_test(
                dset_loaders, base_network,
                test_10crop=prep_config["test_10crop"],
                gpu=use_gpu)
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            # Terminate each entry so successive log lines do not run together.
            config["out_file"].write(log_str + "\n")
            config["out_file"].flush()
            print(log_str)
        if (i + 1) % config["snapshot_interval"] == 0:
            torch.save(
                nn.Sequential(base_network),
                osp.join(config["output_path"],
                         "iter_{:05d}_model.pth.tar".format(i)))

        ## train one iter
        base_network.train(True)
        optimizer = lr_scheduler(param_lr, optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        # Built-in next() works on any iterator; the loader iterator's own
        # .next() method is not available on every PyTorch version.
        inputs_source, labels_source = next(iter_source)
        inputs_target, _ = next(iter_target)  # target labels are not used
        if use_gpu:
            inputs_source = Variable(inputs_source).cuda()
            inputs_target = Variable(inputs_target).cuda()
            labels_source = Variable(labels_source).cuda()
        else:
            inputs_source = Variable(inputs_source)
            inputs_target = Variable(inputs_target)
            labels_source = Variable(labels_source)

        # Source samples come first in the batch, target samples second.
        inputs = torch.cat((inputs_source, inputs_target), dim=0)
        features, outputs = base_network(inputs)

        ad_net.train(True)
        # NOTE: per-class re-weighting of the adversarial loss is disabled in
        # this variant; every sample gets weight 1.
        weight_ad = torch.ones(inputs.size(0))
        transfer_loss = transfer_criterion(features, ad_net,
                                           gradient_reverse_layer,
                                           weight_ad, use_gpu)

        # Slice exactly the source rows: source and target batch sizes are
        # configured independently, so "half of the batch" is not guaranteed
        # to match the number of source labels.
        classifier_loss = class_criterion(
            outputs.narrow(0, 0, inputs_source.size(0)), labels_source)

        total_loss = loss_params["trade_off"] * transfer_loss + classifier_loss
        total_loss.backward()
        optimizer.step()

    cv_loss = dev.cross_validation_loss(
        base_network, base_network, cls_source_list,
        data_config["target"]["list_path"], cls_validation_list, class_num,
        prep_config["resize_size"], prep_config["crop_size"],
        data_config["target"]["batch_size"], use_gpu)
    print(cv_loss)
# Example #5
# 0
def train(config):
    """Train a network with an adversarial transfer loss; keep the best model.

    Data loaders, the base network and the adversarial network are built from
    ``config``; then ``config["num_iterations"]`` optimisation steps are run
    on ``loss["trade_off"] * transfer_loss + classifier_loss``.

    Every ``config["test_interval"]`` iterations (starting at iteration 0)
    the test accuracy is logged to ``config["out_file"]`` and printed. Whenever
    it exceeds the best accuracy so far, the network is written to
    ``<output_path>/best_model.pth.tar``. Every ``config["snapshot_interval"]``
    iterations a snapshot is written as well.

    Args:
        config: Nested dict. Keys read here:
            ``prep`` (``resize_size``, ``crop_size``, ``test_10crop``),
            ``loss`` (``trade_off``),
            ``data`` (``source``/``target``/``test``: ``list_path``,
            ``batch_size``),
            ``network`` (``name`` callable, ``params`` with ``class_num``,
            ``new_cls``, ``use_bottleneck``),
            ``optimizer`` (``type``, ``optim_params``, ``lr_type``,
            ``lr_param``), ``high``, ``num_iterations``, ``test_interval``,
            ``snapshot_interval``, ``out_file`` (writable file object) and
            ``output_path``.

    Returns:
        The best test accuracy observed (0.0 if no evaluation improved on it).
    """

    def _read_lines(list_path):
        # Read an image-list file and close the handle deterministically.
        with open(list_path) as list_file:
            return list_file.readlines()

    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(
        resize_size=prep_config["resize_size"],
        crop_size=prep_config["crop_size"])
    prep_dict["target"] = prep.image_train(
        resize_size=prep_config["resize_size"],
        crop_size=prep_config["crop_size"])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(
            resize_size=prep_config["resize_size"],
            crop_size=prep_config["crop_size"])
    else:
        prep_dict["test"] = prep.image_test(
            resize_size=prep_config["resize_size"],
            crop_size=prep_config["crop_size"])

    ## set loss
    class_criterion = nn.CrossEntropyLoss()
    transfer_criterion = loss.PADA
    loss_params = config["loss"]

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    dsets["source"] = ImageList(
        _read_lines(data_config["source"]["list_path"]),
        transform=prep_dict["source"])
    dset_loaders["source"] = util_data.DataLoader(
        dsets["source"],
        batch_size=data_config["source"]["batch_size"],
        shuffle=True, num_workers=4)
    dsets["target"] = ImageList(
        _read_lines(data_config["target"]["list_path"]),
        transform=prep_dict["target"])
    dset_loaders["target"] = util_data.DataLoader(
        dsets["target"],
        batch_size=data_config["target"]["batch_size"],
        shuffle=True, num_workers=4)

    if prep_config["test_10crop"]:
        # One test loader and one target loader per crop transform.
        for i in range(10):
            dsets["test" + str(i)] = ImageList(
                _read_lines(data_config["test"]["list_path"]),
                transform=prep_dict["test"]["val" + str(i)])
            dset_loaders["test" + str(i)] = util_data.DataLoader(
                dsets["test" + str(i)],
                batch_size=data_config["test"]["batch_size"],
                shuffle=False, num_workers=4)

            dsets["target" + str(i)] = ImageList(
                _read_lines(data_config["target"]["list_path"]),
                transform=prep_dict["test"]["val" + str(i)])
            dset_loaders["target" + str(i)] = util_data.DataLoader(
                dsets["target" + str(i)],
                batch_size=data_config["test"]["batch_size"],
                shuffle=False, num_workers=4)
    else:
        dsets["test"] = ImageList(
            _read_lines(data_config["test"]["list_path"]),
            transform=prep_dict["test"])
        dset_loaders["test"] = util_data.DataLoader(
            dsets["test"],
            batch_size=data_config["test"]["batch_size"],
            shuffle=False, num_workers=4)

        dsets["target_test"] = ImageList(
            _read_lines(data_config["target"]["list_path"]),
            transform=prep_dict["test"])
        dset_loaders["target_test"] = MyDataLoader(
            dsets["target_test"],
            batch_size=data_config["test"]["batch_size"],
            shuffle=False, num_workers=4)

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## collect parameters
    # The per-group "lr" values are captured into param_lr below and handed to
    # the scheduler on every iteration.
    if net_config["params"]["new_cls"]:
        if net_config["params"]["use_bottleneck"]:
            parameter_list = [
                {"params": base_network.feature_layers.parameters(), "lr": 1},
                {"params": base_network.bottleneck.parameters(), "lr": 10},
                {"params": base_network.fc.parameters(), "lr": 10}]
        else:
            parameter_list = [
                {"params": base_network.feature_layers.parameters(), "lr": 1},
                {"params": base_network.fc.parameters(), "lr": 10}]
    else:
        parameter_list = [{"params": base_network.parameters(), "lr": 1}]

    ## add additional network for some methods
    ad_net = network.AdversarialNetwork(base_network.output_num())
    gradient_reverse_layer = network.AdversarialLayer(
        high_value=config["high"])
    if use_gpu:
        ad_net = ad_net.cuda()
    parameter_list.append({"params": ad_net.parameters(), "lr": 10})

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](
        parameter_list, **(optimizer_config["optim_params"]))
    param_lr = [param_group["lr"] for param_group in optimizer.param_groups]
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    ## train
    # The iterators are rebuilt one batch before the loader is exhausted, so
    # the last batch of each pass is never consumed. Clamp to 1 so a loader
    # with a single batch does not cause a modulo-by-zero below.
    len_train_source = max(len(dset_loaders["source"]) - 1, 1)
    len_train_target = max(len(dset_loaders["target"]) - 1, 1)
    best_acc = 0.0
    best_model_path = osp.join(config["output_path"], "best_model.pth.tar")
    best_saved = False
    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == 0:
            base_network.train(False)
            temp_acc = image_classification_test(
                dset_loaders, base_network,
                test_10crop=prep_config["test_10crop"],
                gpu=use_gpu)
            if temp_acc > best_acc:
                best_acc = temp_acc
                # nn.Sequential(base_network) only wraps the live network, so
                # holding on to it would track later weight updates. Persist
                # the checkpoint now, while the weights are the best ones.
                torch.save(nn.Sequential(base_network), best_model_path)
                best_saved = True
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            # Terminate each entry so successive log lines do not run together.
            config["out_file"].write(log_str + "\n")
            config["out_file"].flush()
            print(log_str)
        if i % config["snapshot_interval"] == 0:
            torch.save(
                nn.Sequential(base_network),
                osp.join(config["output_path"],
                         "iter_{:05d}_model.pth.tar".format(i)))

        ## train one iter
        base_network.train(True)
        optimizer = lr_scheduler(param_lr, optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        # Built-in next() works on any iterator; the loader iterator's own
        # .next() method is not available on every PyTorch version.
        inputs_source, labels_source = next(iter_source)
        inputs_target, _ = next(iter_target)  # target labels are not used
        if use_gpu:
            inputs_source = Variable(inputs_source).cuda()
            inputs_target = Variable(inputs_target).cuda()
            labels_source = Variable(labels_source).cuda()
        else:
            inputs_source = Variable(inputs_source)
            inputs_target = Variable(inputs_target)
            labels_source = Variable(labels_source)

        # Source samples come first in the batch, target samples second.
        inputs = torch.cat((inputs_source, inputs_target), dim=0)
        features, outputs = base_network(inputs)

        ad_net.train(True)
        # NOTE(review): no per-sample weight tensor is passed here, so use_gpu
        # lands in the fourth positional slot of loss.PADA - confirm this
        # matches the signature of loss.PADA in this project.
        transfer_loss = transfer_criterion(features, ad_net,
                                           gradient_reverse_layer, use_gpu)

        # narrow() needs an integer length; use the actual number of source
        # rows rather than a true-division "half of the batch".
        classifier_loss = class_criterion(
            outputs.narrow(0, 0, inputs_source.size(0)), labels_source)

        total_loss = loss_params["trade_off"] * transfer_loss + classifier_loss
        total_loss.backward()
        optimizer.step()

    if not best_saved:
        # No evaluation ever beat the initial best_acc (or no iteration ran):
        # still leave a checkpoint behind so the output file always exists.
        torch.save(nn.Sequential(base_network), best_model_path)
    return best_acc
# Example #6
# 0
def train(config):
    ## set up summary writer
    writer = SummaryWriter(config['output_path'])
    class_num = config["network"]["params"]["class_num"]
    loss_params = config["loss"]

    class_criterion = nn.CrossEntropyLoss()
    transfer_criterion = loss.PADA
    center_criterion = loss_params["loss_type"](
        num_classes=class_num,
        feat_dim=config["network"]["params"]["bottleneck_dim"])

    ## prepare data
    dsets = {}
    dset_loaders = {}

    #sampling WOR, i guess we leave the 10 in the middle to validate?
    pristine_indices = torch.randperm(len(pristine_x))
    #train
    pristine_x_train = pristine_x[
        pristine_indices[:int(np.floor(.7 * len(pristine_x)))]]
    pristine_y_train = pristine_y[
        pristine_indices[:int(np.floor(.7 * len(pristine_x)))]]
    #validate --- gets passed into test functions in train file
    pristine_x_valid = pristine_x[pristine_indices[
        int(np.floor(.7 * len(pristine_x))):int(np.floor(.8 *
                                                         len(pristine_x)))]]
    pristine_y_valid = pristine_y[pristine_indices[
        int(np.floor(.7 * len(pristine_x))):int(np.floor(.8 *
                                                         len(pristine_x)))]]
    #test for evaluation file
    pristine_x_test = pristine_x[
        pristine_indices[int(np.floor(.8 * len(pristine_x))):]]
    pristine_y_test = pristine_y[
        pristine_indices[int(np.floor(.8 * len(pristine_x))):]]

    noisy_indices = torch.randperm(len(noisy_x))
    #train
    noisy_x_train = noisy_x[noisy_indices[:int(np.floor(.7 * len(noisy_x)))]]
    noisy_y_train = noisy_y[noisy_indices[:int(np.floor(.7 * len(noisy_x)))]]
    #validate --- gets passed into test functions in train file
    noisy_x_valid = noisy_x[noisy_indices[int(np.floor(.7 * len(noisy_x))
                                              ):int(np.floor(.8 *
                                                             len(noisy_x)))]]
    noisy_y_valid = noisy_y[noisy_indices[int(np.floor(.7 * len(noisy_x))
                                              ):int(np.floor(.8 *
                                                             len(noisy_x)))]]
    #test for evaluation file
    noisy_x_test = noisy_x[noisy_indices[int(np.floor(.8 * len(noisy_x))):]]
    noisy_y_test = noisy_y[noisy_indices[int(np.floor(.8 * len(noisy_x))):]]

    dsets["source"] = TensorDataset(pristine_x_train, pristine_y_train)
    dsets["target"] = TensorDataset(noisy_x_train, noisy_y_train)

    dsets["source_valid"] = TensorDataset(pristine_x_valid, pristine_y_valid)
    dsets["target_valid"] = TensorDataset(noisy_x_valid, noisy_y_valid)

    dsets["source_test"] = TensorDataset(pristine_x_test, pristine_y_test)
    dsets["target_test"] = TensorDataset(noisy_x_test, noisy_y_test)

    #put your dataloaders here
    #i stole batch size numbers from below
    dset_loaders["source"] = DataLoader(dsets["source"],
                                        batch_size=128,
                                        shuffle=True,
                                        num_workers=1)
    dset_loaders["target"] = DataLoader(dsets["target"],
                                        batch_size=128,
                                        shuffle=True,
                                        num_workers=1)

    #guessing batch size based on what was done for testing in the original file
    dset_loaders["source_valid"] = DataLoader(dsets["source_valid"],
                                              batch_size=64,
                                              shuffle=True,
                                              num_workers=1)
    dset_loaders["target_valid"] = DataLoader(dsets["target_valid"],
                                              batch_size=64,
                                              shuffle=True,
                                              num_workers=1)

    dset_loaders["source_test"] = DataLoader(dsets["source_test"],
                                             batch_size=64,
                                             shuffle=True,
                                             num_workers=1)
    dset_loaders["target_test"] = DataLoader(dsets["target_test"],
                                             batch_size=64,
                                             shuffle=True,
                                             num_workers=1)

    config['out_file'].write("dataset sizes: source={}, target={}\n".format(
        len(dsets["source"]), len(dsets["target"])))

    config["num_iterations"] = len(
        dset_loaders["source"]) * config["epochs"] + 1
    config["test_interval"] = len(dset_loaders["source"])
    config["snapshot_interval"] = len(
        dset_loaders["source"]) * config["epochs"] * .25
    config["log_iter"] = len(dset_loaders["source"])

    #print the configuration you are using
    config["out_file"].write("config: {}\n".format(config))
    config["out_file"].flush()

    # set up early stop
    early_stop_engine = EarlyStopping(config["early_stop_patience"])

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## add additional network for some methods
    ad_net = network.AdversarialNetwork(base_network.output_num())
    gradient_reverse_layer = network.AdversarialLayer(
        high_value=config["high"])  #,
    #max_iter_value=config["num_iterations"])
    if use_gpu:
        ad_net = ad_net.cuda()

        ## collect parameters
    if "DeepMerge" in args.net:
        parameter_list = [{
            "params": base_network.parameters(),
            "lr_mult": 1,
            'decay_mult': 2
        }]
        parameter_list.append({
            "params": ad_net.parameters(),
            "lr_mult": .1,
            'decay_mult': 2
        })
        parameter_list.append({
            "params": center_criterion.parameters(),
            "lr_mult": 10,
            'decay_mult': 1
        })
    elif "ResNet18" in args.net:
        parameter_list = [{
            "params": base_network.parameters(),
            "lr_mult": 1,
            'decay_mult': 2
        }]
        parameter_list.append({
            "params": ad_net.parameters(),
            "lr_mult": .1,
            'decay_mult': 2
        })
        parameter_list.append({
            "params": center_criterion.parameters(),
            "lr_mult": 10,
            'decay_mult': 1
        })

    if net_config["params"]["new_cls"]:
        if net_config["params"]["use_bottleneck"]:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                            {"params":base_network.bottleneck.parameters(), "lr_mult":10, 'decay_mult':2}, \
                            {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
            parameter_list.append({
                "params": ad_net.parameters(),
                "lr_mult": config["ad_net_mult_lr"],
                'decay_mult': 2
            })
            parameter_list.append({
                "params": center_criterion.parameters(),
                "lr_mult": 10,
                'decay_mult': 1
            })
        else:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                            {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
            parameter_list.append({
                "params": ad_net.parameters(),
                "lr_mult": config["ad_net_mult_lr"],
                'decay_mult': 2
            })
            parameter_list.append({
                "params": center_criterion.parameters(),
                "lr_mult": 10,
                'decay_mult': 1
            })
    else:
        parameter_list = [{
            "params": base_network.parameters(),
            "lr_mult": 1,
            'decay_mult': 2
        }]
        parameter_list.append({
            "params": ad_net.parameters(),
            "lr_mult": config["ad_net_mult_lr"],
            'decay_mult': 2
        })
        parameter_list.append({
            "params": center_criterion.parameters(),
            "lr_mult": 10,
            'decay_mult': 1
        })
    #Should I put lr_mult here as 1 for DeepMerge too? Probably!

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    scan_lr = []
    scan_loss = []

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    len_valid_source = len(dset_loaders["source_valid"])
    len_valid_target = len(dset_loaders["target_valid"])

    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0

    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == 0:
            base_network.train(False)
            if config['loss']['ly_type'] == "cosine":
                temp_acc, _ = image_classification_test(dset_loaders, 'source_valid', \
                    base_network, gpu=use_gpu, verbose = False, save_where = None)
                train_acc, _ = image_classification_test(dset_loaders, 'source', \
                    base_network, gpu=use_gpu, verbose = False, save_where = None)
            elif config['loss']['ly_type'] == "euclidean":
                temp_acc, _ = distance_classification_test(dset_loaders, 'source_valid', \
                    base_network, center_criterion.centers.detach(), gpu=use_gpu, verbose = False, save_where = None)
                train_acc, _ = distance_classification_test(
                    dset_loaders,
                    'source',
                    base_network,
                    center_criterion.centers.detach(),
                    gpu=use_gpu,
                    verbose=False,
                    save_where=None)
            else:
                raise ValueError("no test method for cls loss: {}".format(
                    config['loss']['ly_type']))

            snapshot_obj = {
                'epoch': i / len(dset_loaders["source"]),
                "base_network": base_network.state_dict(),
                'valid accuracy': temp_acc,
                'train accuracy': train_acc,
            }

            if (i + 1) % config["snapshot_interval"] == 0:
                torch.save(
                    snapshot_obj,
                    osp.join(
                        config["output_path"], "epoch_{}_model.pth.tar".format(
                            i / len(dset_loaders["source"]))))

            if config["loss"]["loss_name"] != "laplacian" and config["loss"][
                    "ly_type"] == "euclidean":
                snapshot_obj['center_criterion'] = center_criterion.state_dict(
                )

            if temp_acc > best_acc:
                best_acc = temp_acc

                # save best model
                torch.save(
                    snapshot_obj,
                    osp.join(config["output_path"], "best_model.pth.tar"))

            log_str = "epoch: {}, {} validation accuracy: {:.5f}, {} training accuracy: {:.5f}\n".format(
                i / len(dset_loaders["source"]), config['loss']['ly_type'],
                temp_acc, config['loss']['ly_type'], train_acc)
            config["out_file"].write(log_str)
            config["out_file"].flush()
            writer.add_scalar("validation accuracy", temp_acc,
                              i / len(dset_loaders["source"]))
            writer.add_scalar("training accuracy", train_acc,
                              i / len(dset_loaders["source"]))

        ## train one iter
        base_network.train(True)

        if i % config["log_iter"] == 0:
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)

        if config["optimizer"]["lr_type"] == "one-cycle":
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)

        if config["optimizer"]["lr_type"] == "linear":
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)

        optim = optimizer.state_dict()
        scan_lr.append(optim['param_groups'][0]['lr'])

        optimizer.zero_grad()

        try:
            inputs_source, labels_source = iter(dset_loaders["source"]).next()
            inputs_target, labels_target = iter(dset_loaders["target"]).next()
        except StopIteration:
            iter(dset_loaders["source"])
            iter(dset_loaders["target"])

        if use_gpu:
            inputs_source, inputs_target, labels_source = \
                Variable(inputs_source).cuda(), Variable(inputs_target).cuda(), \
                Variable(labels_source).cuda()
        else:
            inputs_source, inputs_target, labels_source = Variable(inputs_source), \
                Variable(inputs_target), Variable(labels_source)

        inputs = torch.cat((inputs_source, inputs_target), dim=0)
        source_batch_size = inputs_source.size(0)

        if config['loss']['ly_type'] == 'cosine':
            features, logits = base_network(inputs)
            source_logits = logits.narrow(0, 0, source_batch_size)
        elif config['loss']['ly_type'] == 'euclidean':
            features, _ = base_network(inputs)
            logits = -1.0 * loss.distance_to_centroids(
                features, center_criterion.centers.detach())
            source_logits = logits.narrow(0, 0, source_batch_size)

        ad_net.train(True)
        weight_ad = torch.ones(inputs.size(0))
        transfer_loss = transfer_criterion(features, ad_net, gradient_reverse_layer, \
                                            weight_ad, use_gpu)
        ad_out, _ = ad_net(features.detach())
        ad_acc, source_acc_ad, target_acc_ad = domain_cls_accuracy(ad_out)

        # source domain classification task loss
        classifier_loss = class_criterion(source_logits, labels_source.long())
        # the fisher loss on the labeled source domain is only added when
        # config["fisher_or_no"] is not 'no'

        if config["fisher_or_no"] == 'no':
            total_loss = loss_params["trade_off"] * transfer_loss \
            + classifier_loss

            scan_loss.append(total_loss.cpu().float().item())

            total_loss.backward()

            ######################################
            # Plot embeddings periodically.
            if args.blobs is not None and i / len(
                    dset_loaders["source"]) % 50 == 0:
                visualizePerformance(base_network,
                                     dset_loaders["source"],
                                     dset_loaders["target"],
                                     batch_size=128,
                                     domain_classifier=ad_net,
                                     num_of_samples=100,
                                     imgName='embedding_' +
                                     str(i / len(dset_loaders["source"])),
                                     save_dir=osp.join(config["output_path"],
                                                       "blobs"))
            ##########################################

            # if center_grad is not None:
            #     # clear mmc_loss
            #     center_criterion.centers.grad.zero_()
            #     # Manually assign centers gradients other than using autograd
            #     center_criterion.centers.backward(center_grad)

            optimizer.step()

            if i % config["log_iter"] == 0:

                if config['lr_scan'] != 'no':
                    if not osp.exists(
                            osp.join(config["output_path"],
                                     "learning_rate_scan")):
                        os.makedirs(
                            osp.join(config["output_path"],
                                     "learning_rate_scan"))

                    plot_learning_rate_scan(
                        scan_lr, scan_loss, i / len(dset_loaders["source"]),
                        osp.join(config["output_path"], "learning_rate_scan"))

                if config['grad_vis'] != 'no':
                    if not osp.exists(
                            osp.join(config["output_path"], "gradients")):
                        os.makedirs(
                            osp.join(config["output_path"], "gradients"))

                    plot_grad_flow(
                        osp.join(config["output_path"], "gradients"),
                        i / len(dset_loaders["source"]),
                        base_network.named_parameters())

                config['out_file'].write(
                    'epoch {}: train total loss={:0.4f}, train transfer loss={:0.4f}, train classifier loss={:0.4f},'
                    'train source+target domain accuracy={:0.4f}, train source domain accuracy={:0.4f}, train target domain accuracy={:0.4f}\n'
                    .format(
                        i / len(dset_loaders["source"]),
                        total_loss.data.cpu().float().item(),
                        transfer_loss.data.cpu().float().item(),
                        classifier_loss.data.cpu().float().item(),
                        ad_acc,
                        source_acc_ad,
                        target_acc_ad,
                    ))
                config['out_file'].flush()
                writer.add_scalar("training total loss",
                                  total_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training classifier loss",
                                  classifier_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training transfer loss",
                                  transfer_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training source+target domain accuracy",
                                  ad_acc, i / len(dset_loaders["source"]))
                writer.add_scalar("training source domain accuracy",
                                  source_acc_ad,
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training target domain accuracy",
                                  target_acc_ad,
                                  i / len(dset_loaders["source"]))

                #attempted validation step
                for j in range(0, len(dset_loaders["source_valid"])):
                    base_network.train(False)
                    with torch.no_grad():

                        try:
                            inputs_valid_source, labels_valid_source = iter(
                                dset_loaders["source_valid"]).next()
                            inputs_valid_target, labels_valid_target = iter(
                                dset_loaders["target_valid"]).next()
                        except StopIteration:
                            iter(dset_loaders["source_valid"])
                            iter(dset_loaders["target_valid"])

                        if use_gpu:
                            inputs_valid_source, inputs_valid_target, labels_valid_source = \
                                Variable(inputs_valid_source).cuda(), Variable(inputs_valid_target).cuda(), \
                                Variable(labels_valid_source).cuda()
                        else:
                            inputs_valid_source, inputs_valid_target, labels_valid_source = Variable(inputs_valid_source), \
                                Variable(inputs_valid_target), Variable(labels_valid_source)

                        valid_inputs = torch.cat(
                            (inputs_valid_source, inputs_valid_target), dim=0)
                        valid_source_batch_size = inputs_valid_source.size(0)

                        if config['loss']['ly_type'] == 'cosine':
                            features, logits = base_network(valid_inputs)
                            source_logits = logits.narrow(
                                0, 0, valid_source_batch_size)
                        elif config['loss']['ly_type'] == 'euclidean':
                            features, _ = base_network(valid_inputs)
                            logits = -1.0 * loss.distance_to_centroids(
                                features, center_criterion.centers.detach())
                            source_logits = logits.narrow(
                                0, 0, valid_source_batch_size)

                        ad_net.train(False)
                        weight_ad = torch.ones(valid_inputs.size(0))
                        transfer_loss = transfer_criterion(features, ad_net, gradient_reverse_layer, \
                                                           weight_ad, use_gpu)
                        ad_out, _ = ad_net(features.detach())
                        ad_acc, source_acc_ad, target_acc_ad = domain_cls_accuracy(
                            ad_out)

                        # source domain classification task loss
                        classifier_loss = class_criterion(
                            source_logits, labels_valid_source.long())

                        #if config["fisher_or_no"] == 'no':
                        total_loss = loss_params["trade_off"] * transfer_loss \
                                    + classifier_loss

                    if j % len(dset_loaders["source_valid"]) == 0:
                        config['out_file'].write(
                            'epoch {}: valid total loss={:0.4f}, valid transfer loss={:0.4f}, valid classifier loss={:0.4f},'
                            'valid source+target domain accuracy={:0.4f}, valid source domain accuracy={:0.4f}, valid target domain accuracy={:0.4f}\n'
                            .format(
                                i / len(dset_loaders["source"]),
                                total_loss.data.cpu().float().item(),
                                transfer_loss.data.cpu().float().item(),
                                classifier_loss.data.cpu().float().item(),
                                ad_acc,
                                source_acc_ad,
                                target_acc_ad,
                            ))
                        config['out_file'].flush()
                        writer.add_scalar("validation total loss",
                                          total_loss.data.cpu().float().item(),
                                          i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation classifier loss",
                            classifier_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation transfer loss",
                            transfer_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation source+target domain accuracy", ad_acc,
                            i / len(dset_loaders["source"]))
                        writer.add_scalar("validation source domain accuracy",
                                          source_acc_ad,
                                          i / len(dset_loaders["source"]))
                        writer.add_scalar("validation target domain accuracy",
                                          target_acc_ad,
                                          i / len(dset_loaders["source"]))

                        if early_stop_engine.is_stop_training(
                                classifier_loss.cpu().float().item()):
                            config["out_file"].write(
                                "no improvement after {}, stop training at step {}\n"
                                .format(config["early_stop_patience"],
                                        i / len(dset_loaders["source"])))

                            sys.exit()

        else:
            fisher_loss, fisher_intra_loss, fisher_inter_loss, center_grad = center_criterion(
                features.narrow(0, 0, int(inputs.size(0) / 2)),
                labels_source,
                inter_class=config["loss"]["inter_type"],
                intra_loss_weight=loss_params["intra_loss_coef"],
                inter_loss_weight=loss_params["inter_loss_coef"])
            # entropy minimization loss
            em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))

            # final loss
            total_loss = loss_params["trade_off"] * transfer_loss \
                         + fisher_loss \
                         + loss_params["em_loss_coef"] * em_loss \
                         + classifier_loss

            scan_loss.append(total_loss.cpu().float().item())

            total_loss.backward()

            ######################################
            # Plot embeddings periodically.
            if args.blobs is not None and i / len(
                    dset_loaders["source"]) % 50 == 0:
                visualizePerformance(base_network,
                                     dset_loaders["source"],
                                     dset_loaders["target"],
                                     batch_size=128,
                                     num_of_samples=50,
                                     imgName='embedding_' +
                                     str(i / len(dset_loaders["source"])),
                                     save_dir=osp.join(config["output_path"],
                                                       "blobs"))
            ##########################################

            if center_grad is not None:
                # zero out any gradient already accumulated on the centers
                center_criterion.centers.grad.zero_()
                # Manually assign the centers' gradients rather than using autograd
                center_criterion.centers.backward(center_grad)

            optimizer.step()

            if i % config["log_iter"] == 0:

                if config['lr_scan'] != 'no':
                    if not osp.exists(
                            osp.join(config["output_path"],
                                     "learning_rate_scan")):
                        os.makedirs(
                            osp.join(config["output_path"],
                                     "learning_rate_scan"))

                    plot_learning_rate_scan(
                        scan_lr, scan_loss, i / len(dset_loaders["source"]),
                        osp.join(config["output_path"], "learning_rate_scan"))

                if config['grad_vis'] != 'no':
                    if not osp.exists(
                            osp.join(config["output_path"], "gradients")):
                        os.makedirs(
                            osp.join(config["output_path"], "gradients"))

                    plot_grad_flow(
                        osp.join(config["output_path"], "gradients"),
                        i / len(dset_loaders["source"]),
                        base_network.named_parameters())

                config['out_file'].write(
                    'epoch {}: train total loss={:0.4f}, train transfer loss={:0.4f}, train classifier loss={:0.4f}, '
                    'train entropy min loss={:0.4f}, '
                    'train fisher loss={:0.4f}, train intra-group fisher loss={:0.4f}, train inter-group fisher loss={:0.4f}, '
                    'train source+target domain accuracy={:0.4f}, train source domain accuracy={:0.4f}, train target domain accuracy={:0.4f}\n'
                    .format(
                        i / len(dset_loaders["source"]),
                        total_loss.data.cpu().float().item(),
                        transfer_loss.data.cpu().float().item(),
                        classifier_loss.data.cpu().float().item(),
                        em_loss.data.cpu().float().item(),
                        fisher_loss.cpu().float().item(),
                        fisher_intra_loss.cpu().float().item(),
                        fisher_inter_loss.cpu().float().item(),
                        ad_acc,
                        source_acc_ad,
                        target_acc_ad,
                    ))

                config['out_file'].flush()
                writer.add_scalar("training total loss",
                                  total_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training classifier loss",
                                  classifier_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training transfer loss",
                                  transfer_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training entropy minimization loss",
                                  em_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training total fisher loss",
                                  fisher_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training intra-group fisher",
                                  fisher_intra_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training inter-group fisher",
                                  fisher_inter_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training source+target domain accuracy",
                                  ad_acc, i / len(dset_loaders["source"]))
                writer.add_scalar("training source domain accuracy",
                                  source_acc_ad,
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training target domain accuracy",
                                  target_acc_ad,
                                  i / len(dset_loaders["source"]))

                #attempted validation step
                for j in range(0, len(dset_loaders["source_valid"])):
                    base_network.train(False)
                    with torch.no_grad():

                        try:
                            inputs_valid_source, labels_valid_source = iter(
                                dset_loaders["source_valid"]).next()
                            inputs_valid_target, labels_valid_target = iter(
                                dset_loaders["target_valid"]).next()
                        except StopIteration:
                            iter(dset_loaders["source_valid"])
                            iter(dset_loaders["target_valid"])

                        if use_gpu:
                            inputs_valid_source, inputs_valid_target, labels_valid_source = \
                                Variable(inputs_valid_source).cuda(), Variable(inputs_valid_target).cuda(), \
                                Variable(labels_valid_source).cuda()
                        else:
                            inputs_valid_source, inputs_valid_target, labels_valid_source = Variable(inputs_valid_source), \
                                Variable(inputs_valid_target), Variable(labels_valid_source)

                        valid_inputs = torch.cat(
                            (inputs_valid_source, inputs_valid_target), dim=0)
                        valid_source_batch_size = inputs_valid_source.size(0)

                        if config['loss']['ly_type'] == 'cosine':
                            features, logits = base_network(valid_inputs)
                            source_logits = logits.narrow(
                                0, 0, valid_source_batch_size)
                        elif config['loss']['ly_type'] == 'euclidean':
                            features, _ = base_network(valid_inputs)
                            logits = -1.0 * loss.distance_to_centroids(
                                features, center_criterion.centers.detach())
                            source_logits = logits.narrow(
                                0, 0, valid_source_batch_size)

                        ad_net.train(False)
                        weight_ad = torch.ones(valid_inputs.size(0))
                        transfer_loss = transfer_criterion(features, ad_net, gradient_reverse_layer, \
                                   weight_ad, use_gpu)
                        ad_out, _ = ad_net(features.detach())
                        ad_acc, source_acc_ad, target_acc_ad = domain_cls_accuracy(
                            ad_out)

                        # source domain classification task loss
                        classifier_loss = class_criterion(
                            source_logits, labels_valid_source.long())

                        # fisher loss on labeled source domain
                        fisher_loss, fisher_intra_loss, fisher_inter_loss, center_grad = center_criterion(
                            features.narrow(0, 0,
                                            int(valid_inputs.size(0) / 2)),
                            labels_valid_source,
                            inter_class=loss_params["inter_type"],
                            intra_loss_weight=loss_params["intra_loss_coef"],
                            inter_loss_weight=loss_params["inter_loss_coef"])
                        # entropy minimization loss
                        em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))

                        # final loss
                        total_loss = loss_params["trade_off"] * transfer_loss \
                                     + fisher_loss \
                                     + loss_params["em_loss_coef"] * em_loss \
                                     + classifier_loss

                    if j % len(dset_loaders["source_valid"]) == 0:
                        config['out_file'].write(
                            'epoch {}: valid total loss={:0.4f}, valid transfer loss={:0.4f}, valid classifier loss={:0.4f}, '
                            'valid entropy min loss={:0.4f}, '
                            'valid fisher loss={:0.4f}, valid intra-group fisher loss={:0.4f}, valid inter-group fisher loss={:0.4f}, '
                            'valid source+target domain accuracy={:0.4f}, valid source domain accuracy={:0.4f}, valid target domain accuracy={:0.4f}\n'
                            .format(
                                i / len(dset_loaders["source"]),
                                total_loss.data.cpu().float().item(),
                                transfer_loss.data.cpu().float().item(),
                                classifier_loss.data.cpu().float().item(),
                                em_loss.data.cpu().float().item(),
                                fisher_loss.cpu().float().item(),
                                fisher_intra_loss.cpu().float().item(),
                                fisher_inter_loss.cpu().float().item(),
                                ad_acc,
                                source_acc_ad,
                                target_acc_ad,
                            ))

                        config['out_file'].flush()
                        writer.add_scalar("validation total loss",
                                          total_loss.data.cpu().float().item(),
                                          i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation classifier loss",
                            classifier_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation entropy minimization loss",
                            em_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation transfer loss",
                            transfer_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation entropy minimization loss",
                            em_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation total fisher loss",
                            fisher_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation intra-group fisher",
                            fisher_intra_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation inter-group fisher",
                            fisher_inter_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation source+target domain accuracy", ad_acc,
                            i / len(dset_loaders["source"]))
                        writer.add_scalar("validation source domain accuracy",
                                          source_acc_ad,
                                          i / len(dset_loaders["source"]))
                        writer.add_scalar("validation target domain accuracy",
                                          target_acc_ad,
                                          i / len(dset_loaders["source"]))

                        if early_stop_engine.is_stop_training(
                                classifier_loss.cpu().float().item()):
                            config["out_file"].write(
                                "no improvement after {}, stop training at step {}\n"
                                .format(config["early_stop_patience"],
                                        i / len(dset_loaders["source"])))

                            sys.exit()

    return best_acc
def _split_train_valid_test(inputs, targets):
    """Shuffle paired data and cut it into 70% train / 10% valid / 20% test.

    One random permutation is drawn and the same index slices are applied to
    ``inputs`` and ``targets`` so samples and labels stay aligned.

    Returns:
        A list of three ``(inputs, targets)`` pairs in train, validation,
        test order.
    """
    total = len(inputs)
    order = torch.randperm(total)
    train_end = int(np.floor(.7 * total))
    valid_end = int(np.floor(.8 * total))
    index_sets = (order[:train_end], order[train_end:valid_end],
                  order[valid_end:])
    return [(inputs[idx], targets[idx]) for idx in index_sets]


def _build_loaders(source_data, target_data):
    """Split both domains and wrap every split in a shuffling DataLoader.

    Args:
        source_data: ``(inputs, targets)`` pair for the source domain.
        target_data: ``(inputs, targets)`` pair for the target domain.

    Returns:
        ``(dsets, dset_loaders)``: two dicts keyed by "source" / "target"
        for the training splits and "<domain>_valid" / "<domain>_test" for
        the held-out splits.
    """
    # (key suffix, batch size): 36 for training, 4 for validation and test.
    split_specs = (("", 36), ("_valid", 4), ("_test", 4))
    dsets = {}
    dset_loaders = {}
    for domain, (inputs, targets) in (("source", source_data),
                                      ("target", target_data)):
        splits = _split_train_valid_test(inputs, targets)
        for (suffix, batch_size), (split_x, split_y) in zip(
                split_specs, splits):
            key = domain + suffix
            dsets[key] = TensorDataset(split_x, split_y)
            dset_loaders[key] = DataLoader(dsets[key],
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=1)
    return dsets, dset_loaders


def _compute_losses(config, base_network, ad_net, gradient_reverse_layer,
                    class_criterion, transfer_criterion, center_criterion,
                    inputs_source, inputs_target, labels_source, use_gpu):
    """Run one source+target batch through the networks and build all losses.

    The caller is responsible for putting ``base_network`` and ``ad_net`` in
    train or eval mode (and for wrapping the call in ``torch.no_grad()`` when
    validating); this function only does the forward computation.

    Returns:
        A dict with the tensors "total_loss", "transfer_loss",
        "classifier_loss", "em_loss", "fisher_loss", "fisher_intra_loss",
        "fisher_inter_loss", the value "center_grad" returned by
        ``center_criterion`` (may be ``None``), and the three values returned
        by ``domain_cls_accuracy`` as "ad_acc", "source_acc_ad",
        "target_acc_ad".

    Raises:
        ValueError: if ``config['loss']['ly_type']`` is neither "cosine" nor
            "euclidean".
    """
    loss_params = config["loss"]
    if use_gpu:
        inputs_source = inputs_source.cuda()
        inputs_target = inputs_target.cuda()
        labels_source = labels_source.cuda()

    # Source rows come first, target rows second.
    inputs = torch.cat((inputs_source, inputs_target), dim=0)
    source_batch_size = inputs_source.size(0)

    ly_type = loss_params['ly_type']
    if ly_type == 'cosine':
        features, logits = base_network(inputs)
    elif ly_type == 'euclidean':
        features, _ = base_network(inputs)
        # Negative distance to each class centre plays the role of a logit.
        logits = -1.0 * loss.distance_to_centroids(
            features, center_criterion.centers.detach())
    else:
        raise ValueError("no train method for cls loss: {}".format(ly_type))
    source_logits = logits.narrow(0, 0, source_batch_size)

    # domain-adversarial transfer loss over the whole (source+target) batch
    weight_ad = torch.ones(inputs.size(0))
    transfer_loss = transfer_criterion(features, ad_net,
                                       gradient_reverse_layer, weight_ad,
                                       use_gpu)
    ad_out, _ = ad_net(features.detach())
    ad_acc, source_acc_ad, target_acc_ad = domain_cls_accuracy(ad_out)

    # source domain classification task loss
    # (CrossEntropyLoss wants int64 class indices, hence .long())
    classifier_loss = class_criterion(source_logits, labels_source.long())

    # fisher loss on labeled source domain; slice by the real source batch
    # size so features and labels always line up row for row
    fisher_loss, fisher_intra_loss, fisher_inter_loss, center_grad = center_criterion(
        features.narrow(0, 0, source_batch_size),
        labels_source,
        inter_class=loss_params["inter_type"],
        intra_loss_weight=loss_params["intra_loss_coef"],
        inter_loss_weight=loss_params["inter_loss_coef"])

    # entropy minimization loss over source and target predictions
    em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))

    # final loss
    total_loss = loss_params["trade_off"] * transfer_loss \
                 + fisher_loss \
                 + loss_params["em_loss_coef"] * em_loss \
                 + classifier_loss

    return {
        "total_loss": total_loss,
        "transfer_loss": transfer_loss,
        "classifier_loss": classifier_loss,
        "em_loss": em_loss,
        "fisher_loss": fisher_loss,
        "fisher_intra_loss": fisher_intra_loss,
        "fisher_inter_loss": fisher_inter_loss,
        "center_grad": center_grad,
        "ad_acc": ad_acc,
        "source_acc_ad": source_acc_ad,
        "target_acc_ad": target_acc_ad,
    }


def _log_step(config, writer, step, phase, tag_prefix, acc_tag_prefix,
              results):
    """Write one step's losses to the text log and to the summary writer.

    Args:
        config: run configuration; only ``config['out_file']`` is used.
        writer: summary writer receiving the scalars.
        step: iteration number used as the x-axis value.
        phase: word used in the text log line ("train" or "valid").
        tag_prefix: prefix of the loss scalar tags ("training"/"validation").
        acc_tag_prefix: prefix of the domain-accuracy scalar tags, so that
            training and validation curves do not share a tag.
        results: dict produced by ``_compute_losses``.
    """

    def as_float(tensor):
        return tensor.detach().cpu().float().item()

    total = as_float(results["total_loss"])
    transfer = as_float(results["transfer_loss"])
    classifier = as_float(results["classifier_loss"])
    entropy = as_float(results["em_loss"])
    fisher = as_float(results["fisher_loss"])
    intra = as_float(results["fisher_intra_loss"])
    inter = as_float(results["fisher_inter_loss"])

    config['out_file'].write(
        'iter {step}: {p} total loss={total:0.4f}, {p} transfer loss={transfer:0.4f}, {p} classifier loss={classifier:0.4f}, '
        '{p} entropy min loss={entropy:0.4f}, '
        '{p} fisher loss={fisher:0.4f}, {p} intra-group fisher loss={intra:0.4f}, {p} inter-group fisher loss={inter:0.4f}, '
        'source+target domain accuracy={ad_acc:0.4f}, source domain accuracy={source_acc:0.4f}, target domain accuracy={target_acc:0.4f}\n'
        .format(step=step,
                p=phase,
                total=total,
                transfer=transfer,
                classifier=classifier,
                entropy=entropy,
                fisher=fisher,
                intra=intra,
                inter=inter,
                ad_acc=results["ad_acc"],
                source_acc=results["source_acc_ad"],
                target_acc=results["target_acc_ad"]))
    config['out_file'].flush()

    writer.add_scalar("{} total loss".format(tag_prefix), total, step)
    writer.add_scalar("{} classifier loss".format(tag_prefix), classifier,
                      step)
    writer.add_scalar("{} transfer_loss".format(tag_prefix), transfer, step)
    writer.add_scalar("{} total fisher loss".format(tag_prefix), fisher, step)
    writer.add_scalar("{} intra-group fisher".format(tag_prefix), intra, step)
    writer.add_scalar("{} inter-group fisher".format(tag_prefix), inter, step)
    writer.add_scalar(acc_tag_prefix + "source+target domain accuracy",
                      results["ad_acc"], step)
    writer.add_scalar(acc_tag_prefix + "source domain accuracy",
                      results["source_acc_ad"], step)
    writer.add_scalar(acc_tag_prefix + "target domain accuracy",
                      results["target_acc_ad"], step)


def train(config, data_import=None):
    """Train the domain-adaptation model and return the best validation accuracy.

    Each iteration performs one optimisation step on a source+target training
    batch, followed by a gradient-free forward pass on one validation batch.
    Every ``config["test_interval"]`` iterations the model is scored on the
    source validation and source training splits, the best model is saved to
    ``<output_path>/best_model.pth.tar`` and early stopping is checked.

    Args:
        config: nested dict. Keys read here: "output_path",
            "early_stop_patience", "out_file" (writable file object), "high",
            "num_iterations", "test_interval", "snapshot_interval",
            "log_iter", "network" ("name", "params"), "loss" ("loss_type",
            "loss_name", "ly_type", "inter_type", "intra_loss_coef",
            "inter_loss_coef", "trade_off", "em_loss_coef") and "optimizer"
            ("type", "optim_params", "lr_type", "lr_param").
        data_import: optional ``(pristine_x, pristine_y, noisy_x, noisy_y)``
            tuple. When omitted, the four module-level variables of those
            names are used instead.

    Returns:
        The highest validation accuracy seen at any evaluation step.

    Raises:
        ValueError: if ``config['loss']['ly_type']`` is neither "cosine" nor
            "euclidean".
    """
    ## set up summary writer
    writer = SummaryWriter(config['output_path'])

    # set up early stop
    early_stop_engine = EarlyStopping(config["early_stop_patience"])

    ## set loss
    class_num = config["network"]["params"]["class_num"]
    loss_params = config["loss"]

    class_criterion = nn.CrossEntropyLoss()
    transfer_criterion = loss.PADA
    center_criterion = loss_params["loss_type"](
        num_classes=class_num,
        feat_dim=config["network"]["params"]["bottleneck_dim"])

    ## prepare data
    if data_import is None:
        # Legacy call path: the data lives in module-level variables.
        # NOTE(review): these globals are not defined in the visible part of
        # the file -- confirm the caller sets them, or pass data_import.
        data_import = (pristine_x, pristine_y, noisy_x, noisy_y)
    source_x, source_y, target_x, target_y = data_import

    # source = pristine data, target = noisy data; sampled without replacement
    dsets, dset_loaders = _build_loaders((source_x, source_y),
                                         (target_x, target_y))

    config['out_file'].write("dataset sizes: source={}, target={}\n".format(
        len(dsets["source"]), len(dsets["target"])))

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## collect parameters
    if net_config["params"]["new_cls"]:
        parameter_list = [{
            "params": base_network.feature_layers.parameters(),
            "lr_mult": 1,
            'decay_mult': 2
        }]
        if net_config["params"]["use_bottleneck"]:
            parameter_list.append({
                "params": base_network.bottleneck.parameters(),
                "lr_mult": 10,
                'decay_mult': 2
            })
        parameter_list.append({
            "params": base_network.fc.parameters(),
            "lr_mult": 10,
            'decay_mult': 2
        })
    else:
        parameter_list = [{
            "params": base_network.parameters(),
            "lr_mult": 1,
            'decay_mult': 2
        }]

    ## add additional network for some methods
    ad_net = network.AdversarialNetwork(base_network.output_num())
    gradient_reverse_layer = network.AdversarialLayer(
        high_value=config["high"])
    if use_gpu:
        ad_net = ad_net.cuda()
    parameter_list.append({
        "params": ad_net.parameters(),
        "lr_mult": 10,
        'decay_mult': 2
    })
    parameter_list.append({
        "params": center_criterion.parameters(),
        "lr_mult": 10,
        'decay_mult': 1
    })

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](
        parameter_list, **(optimizer_config["optim_params"]))
    # remember each group's initial lr; handed to the scheduler every step
    param_lr = [param_group["lr"] for param_group in optimizer.param_groups]
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    ## train
    # Iterators are rebuilt every (number of batches - 1) steps, so the final
    # batch of each loader -- the only one that can be short -- is never
    # drawn (presumably to keep source and target batches the same size).
    # max(..., 1) keeps the modulus non-zero for single-batch loaders.
    len_train_source = max(len(dset_loaders["source"]) - 1, 1)
    len_train_target = max(len(dset_loaders["target"]) - 1, 1)
    len_valid_source = max(len(dset_loaders["source_valid"]) - 1, 1)
    len_valid_target = max(len(dset_loaders["target_valid"]) - 1, 1)

    best_acc = 0.0
    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == 0:
            base_network.train(False)
            if config['loss']['ly_type'] == "cosine":
                temp_acc, _ = image_classification_test(dset_loaders,
                                                        'source_valid',
                                                        base_network,
                                                        gpu=use_gpu)
                train_acc, _ = image_classification_test(dset_loaders,
                                                         'source',
                                                         base_network,
                                                         gpu=use_gpu)
            elif config['loss']['ly_type'] == "euclidean":
                temp_acc, _ = distance_classification_test(
                    dset_loaders,
                    'source_valid',
                    base_network,
                    center_criterion.centers.detach(),
                    gpu=use_gpu)
                train_acc, _ = distance_classification_test(
                    dset_loaders,
                    'source',
                    base_network,
                    center_criterion.centers.detach(),
                    gpu=use_gpu)
            else:
                raise ValueError("no test method for cls loss: {}".format(
                    config['loss']['ly_type']))

            snapshot_obj = {
                'step': i,
                "base_network": base_network.state_dict(),
                'precision': temp_acc,
                'train accuracy': train_acc,
            }
            if config["loss"]["loss_name"] != "laplacian" and config["loss"][
                    "ly_type"] == "euclidean":
                snapshot_obj['center_criterion'] = center_criterion.state_dict(
                )
            if temp_acc > best_acc:
                best_acc = temp_acc
                # save best model
                torch.save(
                    snapshot_obj,
                    osp.join(config["output_path"], "best_model.pth.tar"))
            log_str = "iter: {:05d}, {} validation accuracy: {:.5f}, {} training accuracy: {:.5f}\n".format(
                i, config['loss']['ly_type'], temp_acc,
                config['loss']['ly_type'], train_acc)
            config["out_file"].write(log_str)
            config["out_file"].flush()
            writer.add_scalar("validation accuracy", temp_acc, i)
            writer.add_scalar("training accuracy", train_acc, i)

            if early_stop_engine.is_stop_training(temp_acc):
                config["out_file"].write(
                    "no improvement after {}, stop training at step {}\n".
                    format(config["early_stop_patience"], i))
                break

        # Periodic snapshot; this stores the state captured at the most
        # recent evaluation step, not the weights of iteration i itself.
        if (i + 1) % config["snapshot_interval"] == 0:
            torch.save(
                snapshot_obj,
                osp.join(config["output_path"],
                         "iter_{:05d}_model.pth.tar".format(i)))

        ## train one iter
        base_network.train(True)
        ad_net.train(True)
        optimizer = lr_scheduler(param_lr, optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, _ = next(iter_target)  # target labels are not used

        train_results = _compute_losses(
            config, base_network, ad_net, gradient_reverse_layer,
            class_criterion, transfer_criterion, center_criterion,
            inputs_source, inputs_target, labels_source, use_gpu)

        train_results["total_loss"].backward()

        center_grad = train_results["center_grad"]
        if center_grad is not None:
            # clear the autograd gradient of the centers ...
            center_criterion.centers.grad.zero_()
            # ... and manually assign the gradient supplied by the criterion
            center_criterion.centers.backward(center_grad)

        optimizer.step()

        if i % config["log_iter"] == 0:
            _log_step(config, writer, i, "train", "training", "",
                      train_results)

        ## validate on one batch
        # Dedicated iterators: the training iterators above must never be
        # rebound to the validation loaders.
        # NOTE(review): this pass also routes features through
        # gradient_reverse_layer; if that layer keeps an internal step
        # counter, validation advances it too -- confirm this is intended.
        base_network.eval()
        ad_net.train(False)
        with torch.no_grad():
            if i % len_valid_source == 0:
                iter_source_valid = iter(dset_loaders["source_valid"])
            if i % len_valid_target == 0:
                iter_target_valid = iter(dset_loaders["target_valid"])
            valid_inputs_source, valid_labels_source = next(iter_source_valid)
            valid_inputs_target, _ = next(iter_target_valid)

            # no backprop in eval mode: losses are only computed for logging
            valid_results = _compute_losses(
                config, base_network, ad_net, gradient_reverse_layer,
                class_criterion, transfer_criterion, center_criterion,
                valid_inputs_source, valid_inputs_target, valid_labels_source,
                use_gpu)

        if i % config["log_iter"] == 0:
            _log_step(config, writer, i, "valid", "validation", "validation ",
                      valid_results)

    # flush pending summaries and release the event file
    writer.close()
    return best_acc