def train(config, data_import):
    """Adversarially train a domain-adaptation classifier on tensor data.

    The source domain is the "pristine" image set and the target domain
    the "noisy" one.  Each iteration combines a source classification
    loss with quadratic penalties that push the adversarial domain
    classifier towards 50% accuracy (indistinguishable domains), and
    optionally a Fisher/center discriminant loss plus entropy
    minimization on the logits.

    Args:
        config: training configuration dict; mutated in place with the
            derived ``num_iterations``/``test_interval``/
            ``snapshot_interval``/``log_iter`` entries.
        data_import: tuple ``(pristine_x, pristine_y, noisy_x, noisy_y)``
            of image and label tensors for the two domains.

    Returns:
        float: the negated total loss of the final iteration (negated so
        that larger is better, e.g. for a hyper-parameter tuner).
    """
    class_num = config["network"]["params"]["class_num"]
    loss_params = config["loss"]

    # Classification, adversarial-transfer and discriminant criteria.
    class_criterion = nn.CrossEntropyLoss()
    transfer_criterion = loss.PADA
    center_criterion = config["loss"]["discriminant_loss"](
        num_classes=class_num,
        feat_dim=config["network"]["params"]["bottleneck_dim"])

    ## prepare data
    pristine_x, pristine_y, noisy_x, noisy_y = data_import
    dsets = {}
    dset_loaders = {}

    # Shuffle each domain once up front (sampling without replacement).
    pristine_indices = torch.randperm(len(pristine_x))
    pristine_x_train = pristine_x[pristine_indices]
    pristine_y_train = pristine_y[pristine_indices]

    noisy_indices = torch.randperm(len(noisy_x))
    noisy_x_train = noisy_x[noisy_indices]
    noisy_y_train = noisy_y[noisy_indices]

    dsets["source"] = TensorDataset(pristine_x_train, pristine_y_train)
    dsets["target"] = TensorDataset(noisy_x_train, noisy_y_train)

    dset_loaders["source"] = DataLoader(dsets["source"],
                                        batch_size=128,
                                        shuffle=True,
                                        num_workers=1)
    dset_loaders["target"] = DataLoader(dsets["target"],
                                        batch_size=128,
                                        shuffle=True,
                                        num_workers=1)

    # Derive the iteration budget and logging cadence from the number of
    # source batches per epoch.
    iters_per_epoch = len(dset_loaders["source"])
    config["num_iterations"] = iters_per_epoch * config["epochs"] + 1
    config["test_interval"] = iters_per_epoch
    config["snapshot_interval"] = iters_per_epoch * config["epochs"] * .25
    config["log_iter"] = iters_per_epoch

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])

    # Optionally warm-start from a previously saved best checkpoint.
    if config["ckpt_path"] is not None:
        ckpt = torch.load(config['ckpt_path'] + '/best_model.pth.tar',
                          map_location=torch.device('cpu'))
        base_network.load_state_dict(ckpt['base_network'])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## adversarial domain discriminator and gradient-reversal layer
    ad_net = network.AdversarialNetwork(base_network.output_num())
    gradient_reverse_layer = network.AdversarialLayer(
        high_value=config["high"])
    if use_gpu:
        ad_net = ad_net.cuda()

    ## collect parameters with per-group learning-rate multipliers
    def _whole_net_groups(ad_lr_mult):
        # One group for the full backbone, plus the domain discriminator
        # and the learnable class centers of the discriminant loss.
        return [
            {"params": base_network.parameters(),
             "lr_mult": 1, 'decay_mult': 2},
            {"params": ad_net.parameters(),
             "lr_mult": ad_lr_mult, 'decay_mult': 2},
            {"params": center_criterion.parameters(),
             "lr_mult": 10, 'decay_mult': 1},
        ]

    if "DeepMerge" in config["net"]:
        parameter_list = _whole_net_groups(.1)
    elif "ResNet18" in config["net"]:
        parameter_list = _whole_net_groups(.1)
        # With a freshly initialized classification head, train the head
        # (and optional bottleneck) at 10x the backbone learning rate.
        if net_config["params"]["new_cls"]:
            if net_config["params"]["use_bottleneck"]:
                parameter_list = [
                    {"params": base_network.feature_layers.parameters(),
                     "lr_mult": 1, 'decay_mult': 2},
                    {"params": base_network.bottleneck.parameters(),
                     "lr_mult": 10, 'decay_mult': 2},
                    {"params": base_network.fc.parameters(),
                     "lr_mult": 10, 'decay_mult': 2},
                ]
            else:
                parameter_list = [
                    {"params": base_network.feature_layers.parameters(),
                     "lr_mult": 1, 'decay_mult': 2},
                    {"params": base_network.fc.parameters(),
                     "lr_mult": 10, 'decay_mult': 2},
                ]
            parameter_list.append({
                "params": ad_net.parameters(),
                "lr_mult": config["ad_net_mult_lr"],
                'decay_mult': 2
            })
            parameter_list.append({
                "params": center_criterion.parameters(),
                "lr_mult": 10,
                'decay_mult': 1
            })
    else:
        parameter_list = _whole_net_groups(config["ad_net_mult_lr"])

    ## set optimizer and remember each group's base learning rate
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](
        parameter_list, **(optimizer_config["optim_params"]))
    param_lr = [param_group["lr"] for param_group in optimizer.param_groups]
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    ## train
    # Persistent batch iterators, restarted whenever a domain's epoch ends.
    iter_source = iter(dset_loaders["source"])
    iter_target = iter(dset_loaders["target"])

    for i in range(config["num_iterations"]):

        ## train one iter
        base_network.train(True)

        # Step the LR schedule once per epoch; cyclical schedules
        # ("one-cycle"/"linear") are additionally stepped every iteration
        # (this preserves the original double step at epoch boundaries).
        if i % config["log_iter"] == 0:
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)
        if config["optimizer"]["lr_type"] in ("one-cycle", "linear"):
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)

        optimizer.zero_grad()

        # BUGFIX: the original built a *fresh* iterator every step and
        # called the Python-2-only ``.next()`` method (AttributeError on
        # Python 3); its StopIteration handler also never re-fetched the
        # batch, which would leave the inputs unbound.  Use persistent
        # iterators and restart them on exhaustion instead.
        try:
            inputs_source, labels_source = next(iter_source)
        except StopIteration:
            iter_source = iter(dset_loaders["source"])
            inputs_source, labels_source = next(iter_source)
        try:
            inputs_target, labels_target = next(iter_target)
        except StopIteration:
            iter_target = iter(dset_loaders["target"])
            inputs_target, labels_target = next(iter_target)

        if use_gpu:
            inputs_source, inputs_target, labels_source = \
                Variable(inputs_source).cuda(), Variable(inputs_target).cuda(), \
                Variable(labels_source).cuda()
        else:
            inputs_source, inputs_target, labels_source = Variable(inputs_source), \
                Variable(inputs_target), Variable(labels_source)

        # Forward both domains as one batch; source rows come first.
        inputs = torch.cat((inputs_source, inputs_target), dim=0)
        source_batch_size = inputs_source.size(0)

        if config['loss']['ly_type'] == 'cosine':
            features, logits = base_network(inputs)
            source_logits = logits.narrow(0, 0, source_batch_size)
        elif config['loss']['ly_type'] == 'euclidean':
            # Classify by negative distance to the current class centers.
            features, _ = base_network(inputs)
            logits = -1.0 * loss.distance_to_centroids(
                features, center_criterion.centers.detach())
            source_logits = logits.narrow(0, 0, source_batch_size)

        # Adversarial transfer loss (domain confusion) plus the domain
        # classifier's accuracy diagnostics on detached features.
        ad_net.train(True)
        weight_ad = torch.ones(inputs.size(0))
        transfer_loss = transfer_criterion(features, ad_net, gradient_reverse_layer, \
                                            weight_ad, use_gpu)
        ad_out, _ = ad_net(features.detach())
        ad_acc, source_acc_ad, target_acc_ad = domain_cls_accuracy(ad_out)

        # source domain classification task loss
        classifier_loss = class_criterion(source_logits, labels_source.long())

        # fisher loss on labeled source domain
        if config["fisher_or_no"] == 'no':
            # Task loss plus quadratic penalties that drive the domain
            # classifier towards 50% on both domains.  (A dead
            # ``trade_off * transfer_loss`` pre-assignment that was
            # immediately overwritten in the original has been removed.)
            total_loss = classifier_loss + classifier_loss * (
                0.5 - source_acc_ad)**2 + classifier_loss * (0.5 -
                                                             target_acc_ad)**2

            total_loss.backward()

            optimizer.step()

        else:
            # Fisher (discriminant) loss on the labeled source half of the
            # combined batch.
            fisher_loss, fisher_intra_loss, fisher_inter_loss, center_grad = center_criterion(
                features.narrow(0, 0, int(inputs.size(0) / 2)),
                labels_source,
                inter_class=config["loss"]["inter_type"],
                intra_loss_weight=loss_params["intra_loss_coef"],
                inter_loss_weight=loss_params["inter_loss_coef"])
            # entropy minimization loss
            em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))

            # Final objective.  (As above, the dead overwritten
            # pre-assignment from the original has been removed.)
            total_loss = classifier_loss + fisher_loss + loss_params[
                "em_loss_coef"] * em_loss + classifier_loss * (
                    0.5 - source_acc_ad)**2 + classifier_loss * (
                        0.5 - target_acc_ad)**2

            total_loss.backward()

            if center_grad is not None:
                # clear mmc_loss
                center_criterion.centers.grad.zero_()
                # Manually assign centers gradients other than using autograd
                center_criterion.centers.backward(center_grad)

            optimizer.step()

    # Negate so a caller maximizing the return value minimizes the loss.
    return (-1 * total_loss.cpu().float().item())
# ---- Example #2 ----
def train(config):
    """Train a PADA + Fisher-loss domain-adaptation model on image lists.

    Builds source/target/test ``ImageList`` loaders from ``config["data"]``,
    trains the base network adversarially against a domain discriminator,
    periodically evaluates accuracy, snapshots the best model under
    ``config["output_path"]``, supports early stopping, and logs losses to
    both a text log file and TensorBoard.

    Args:
        config: configuration dict with ``prep``/``data``/``network``/
            ``loss``/``optimizer`` sections, output paths, iteration and
            interval counts, and an open ``out_file`` handle for logging.

    Returns:
        float: best test accuracy observed during training.
    """
    ## set up summary writer
    writer = SummaryWriter(config['output_path'])

    # set up early stop
    early_stop_engine = EarlyStopping(config["early_stop_patience"])

    ## set pre-process transforms (train-time augmentation for both
    ## domains; 10-crop or single-crop evaluation transform for testing)
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
    prep_dict["target"] = prep.image_train( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
    else:
        prep_dict["test"] = prep.image_test( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])

    ## set loss
    class_num = config["network"]["params"]["class_num"]
    loss_params = config["loss"]

    # Classification, adversarial-transfer and discriminant criteria.
    class_criterion = nn.CrossEntropyLoss()
    transfer_criterion = loss.PADA
    center_criterion = loss_params["loss_type"](num_classes=class_num, 
                                       feat_dim=config["network"]["params"]["bottleneck_dim"])

    ## prepare data (stratified subsampling of each image list)
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    dsets["source"] = ImageList(stratify_sampling(open(data_config["source"]["list_path"]).readlines(), prep_config["source_size"]), \
                                transform=prep_dict["source"])
    dset_loaders["source"] = util_data.DataLoader(dsets["source"], \
            batch_size=data_config["source"]["batch_size"], \
            shuffle=True, num_workers=1)
    dsets["target"] = ImageList(stratify_sampling(open(data_config["target"]["list_path"]).readlines(), prep_config["target_size"]), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = util_data.DataLoader(dsets["target"], \
            batch_size=data_config["target"]["batch_size"], \
            shuffle=True, num_workers=1)

    if prep_config["test_10crop"]:
        # One loader per crop; evaluation averages over the ten crops.
        for i in range(10):
            dsets["test"+str(i)] = ImageList(stratify_sampling(open(data_config["test"]["list_path"]).readlines(), ratio=prep_config['target_size']), \
                                transform=prep_dict["test"]["val"+str(i)])
            dset_loaders["test"+str(i)] = util_data.DataLoader(dsets["test"+str(i)], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=1)

            dsets["target"+str(i)] = ImageList(stratify_sampling(open(data_config["target"]["list_path"]).readlines(), ratio=prep_config['target_size']), \
                                transform=prep_dict["test"]["val"+str(i)])
            dset_loaders["target"+str(i)] = util_data.DataLoader(dsets["target"+str(i)], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=1)
    else:
        dsets["test"] = ImageList(stratify_sampling(open(data_config["test"]["list_path"]).readlines(), ratio=prep_config['target_size']), \
                                transform=prep_dict["test"])
        dset_loaders["test"] = util_data.DataLoader(dsets["test"], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=1)

        dsets["target_test"] = ImageList(stratify_sampling(open(data_config["target"]["list_path"]).readlines(), ratio=prep_config['target_size']), \
                                transform=prep_dict["test"])
        dset_loaders["target_test"] = MyDataLoader(dsets["target_test"], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=1)

    config['out_file'].write("dataset sizes: source={}, target={}\n".format(
        len(dsets["source"]), len(dsets["target"])))

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## collect parameters: a fresh classification head (and optional
    ## bottleneck) trains at 10x the backbone learning rate
    if net_config["params"]["new_cls"]:
        if net_config["params"]["use_bottleneck"]:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                            {"params":base_network.bottleneck.parameters(), "lr_mult":10, 'decay_mult':2}, \
                            {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
        else:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                            {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
    else:
        parameter_list = [{"params":base_network.parameters(), "lr_mult":1, 'decay_mult':2}]

    ## adversarial domain discriminator and gradient-reversal layer
    ad_net = network.AdversarialNetwork(base_network.output_num())
    gradient_reverse_layer = network.AdversarialLayer(high_value=config["high"]) #, 
                                                      #max_iter_value=config["num_iterations"])
    if use_gpu:
        ad_net = ad_net.cuda()
    parameter_list.append({"params":ad_net.parameters(), "lr_mult":10, 'decay_mult':2})
    parameter_list.append({"params":center_criterion.parameters(), "lr_mult": 10, 'decay_mult':1})

    ## set optimizer and remember each group's base learning rate
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = [param_group["lr"] for param_group in optimizer.param_groups]
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    ## train
    # "- 1" so the modulo below restarts the iterators one batch early,
    # skipping the (possibly short) final batch of each epoch.
    len_train_source = len(dset_loaders["source"]) - 1
    len_train_target = len(dset_loaders["target"]) - 1
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    for i in range(config["num_iterations"]):
        # Periodic evaluation, checkpointing and early-stop check.
        if i % config["test_interval"] == 0:
            base_network.train(False)
            if config['loss']['ly_type'] == "cosine":
                temp_acc = image_classification_test(dset_loaders, \
                    base_network, test_10crop=prep_config["test_10crop"], \
                    gpu=use_gpu)
            elif config['loss']['ly_type'] == "euclidean":
                temp_acc, _ = distance_classification_test(dset_loaders, \
                    base_network, center_criterion.centers.detach(), test_10crop=prep_config["test_10crop"], \
                    gpu=use_gpu)
            else:
                raise ValueError("no test method for cls loss: {}".format(config['loss']['ly_type']))

            snapshot_obj = {'step': i, 
                            "base_network": base_network.state_dict(), 
                            'precision': temp_acc, 
                            }
            # Center state is only meaningful for euclidean (non-laplacian)
            # discriminant losses.
            if config["loss"]["loss_name"] != "laplacian" and config["loss"]["ly_type"] == "euclidean":
                snapshot_obj['center_criterion'] = center_criterion.state_dict()
            if temp_acc > best_acc:
                best_acc = temp_acc
                # save best model
                torch.save(snapshot_obj, 
                           osp.join(config["output_path"], "best_model.pth.tar"))
            log_str = "iter: {:05d}, {} precision: {:.5f}\n".format(i, config['loss']['ly_type'], temp_acc)
            config["out_file"].write(log_str)
            config["out_file"].flush()
            writer.add_scalar("precision", temp_acc, i)

            if early_stop_engine.is_stop_training(temp_acc):
                config["out_file"].write("no improvement after {}, stop training at step {}\n".format(
                    config["early_stop_patience"], i))
                break

        # Periodic snapshot (reuses the object built at the last test
        # interval; the first test interval fires at i == 0).
        if (i+1) % config["snapshot_interval"] == 0:
            torch.save(snapshot_obj, 
                       osp.join(config["output_path"], "iter_{:05d}_model.pth.tar".format(i)))

        ## train one iter
        base_network.train(True)
        optimizer = lr_scheduler(param_lr, optimizer, i, **schedule_param)
        optimizer.zero_grad()
        # Restart each domain's iterator at its (shortened) epoch boundary.
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        # BUGFIX: ``iterator.next()`` is Python-2-only and raises
        # AttributeError on Python 3; use the ``next()`` builtin.
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        if use_gpu:
            inputs_source, inputs_target, labels_source = \
                Variable(inputs_source).cuda(), Variable(inputs_target).cuda(), \
                Variable(labels_source).cuda()
        else:
            inputs_source, inputs_target, labels_source = Variable(inputs_source), \
                Variable(inputs_target), Variable(labels_source)

        # Forward both domains as one batch; source rows come first.
        inputs = torch.cat((inputs_source, inputs_target), dim=0)
        source_batch_size = inputs_source.size(0)

        if config['loss']['ly_type'] == 'cosine':
            features, logits = base_network(inputs)
            source_logits = logits.narrow(0, 0, source_batch_size)
        elif config['loss']['ly_type'] == 'euclidean':
            # Classify by negative distance to the current class centers.
            features, _ = base_network(inputs)
            logits = -1.0 * loss.distance_to_centroids(features, center_criterion.centers.detach())
            source_logits = logits.narrow(0, 0, source_batch_size)

        # Adversarial transfer loss (domain confusion) plus the domain
        # classifier's accuracy diagnostics on detached features.
        ad_net.train(True)
        weight_ad = torch.ones(inputs.size(0))
        transfer_loss = transfer_criterion(features, ad_net, gradient_reverse_layer, \
                                           weight_ad, use_gpu)
        ad_out, _ = ad_net(features.detach())
        ad_acc, source_acc_ad, target_acc_ad = domain_cls_accuracy(ad_out)

        # source domain classification task loss
        classifier_loss = class_criterion(source_logits, labels_source)
        # fisher loss on labeled source domain (source = first half of batch)
        fisher_loss, fisher_intra_loss, fisher_inter_loss, center_grad = center_criterion(features.narrow(0, 0, int(inputs.size(0)/2)), labels_source, inter_class=config["loss"]["inter_type"], 
                                                                               intra_loss_weight=loss_params["intra_loss_coef"], inter_loss_weight=loss_params["inter_loss_coef"])
        # entropy minimization loss
        em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))

        # final loss
        total_loss = loss_params["trade_off"] * transfer_loss \
                     + fisher_loss \
                     + loss_params["em_loss_coef"] * em_loss \
                     + classifier_loss

        total_loss.backward()
        if center_grad is not None:
            # clear mmc_loss
            center_criterion.centers.grad.zero_()
            # Manually assign centers gradients other than using autograd
            center_criterion.centers.backward(center_grad)

        optimizer.step()

        # Periodic scalar logging to the text log and TensorBoard.
        if i % config["log_iter"] == 0:
            config['out_file'].write('iter {}: total loss={:0.4f}, transfer loss={:0.4f}, cls loss={:0.4f}, '
                'em loss={:0.4f}, '
                'mmc loss={:0.4f}, intra loss={:0.4f}, inter loss={:0.4f}, '
                'ad acc={:0.4f}, source_acc={:0.4f}, target_acc={:0.4f}\n'.format(
                i, total_loss.data.cpu().float().item(), transfer_loss.data.cpu().float().item(), classifier_loss.data.cpu().float().item(), 
                em_loss.data.cpu().float().item(), 
                fisher_loss.cpu().float().item(), fisher_intra_loss.cpu().float().item(), fisher_inter_loss.cpu().float().item(),
                ad_acc, source_acc_ad, target_acc_ad, 
                ))

            config['out_file'].flush()
            writer.add_scalar("total_loss", total_loss.data.cpu().float().item(), i)
            writer.add_scalar("cls_loss", classifier_loss.data.cpu().float().item(), i)
            writer.add_scalar("transfer_loss", transfer_loss.data.cpu().float().item(), i)
            writer.add_scalar("ad_acc", ad_acc, i)
            writer.add_scalar("d_loss/total", fisher_loss.data.cpu().float().item(), i)
            writer.add_scalar("d_loss/intra", fisher_intra_loss.data.cpu().float().item(), i)
            writer.add_scalar("d_loss/inter", fisher_inter_loss.data.cpu().float().item(), i)

    return best_acc
def train(config):

    # Fix random seed and unable deterministic calcualtions
    torch.manual_seed(config["seed"])
    torch.cuda.manual_seed(config["seed"])
    np.random.seed(config["seed"])
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True

    ## In case of highly imbalanced classes one can add weights inside Cross-entropy loss
    ## made by: [1 - (x / sum(nSamples)) for x in nSamples]
    #weights = [0.936,0.064]
    #class_weights = torch.FloatTensor(weights).cuda()

    # Set up summary writer
    writer = SummaryWriter(config['output_path'])

    class_num = config["network"]["params"]["class_num"]
    class_criterion = nn.CrossEntropyLoss(
    )  # optionally add "weight=class_weights" in case of higly imbalanced classes
    transfer_criterion = config["loss"]["name"]
    center_criterion = config["loss"]["discriminant_loss"](
        num_classes=class_num,
        feat_dim=config["network"]["params"]["bottleneck_dim"])
    loss_params = config["loss"]

    # Prepare image data. Image shuffling is fixed with the random seed choice.
    # Train:validation:test = 70:10:20
    dsets = {}
    dset_loaders = {}

    pristine_indices = torch.randperm(len(pristine_x))
    # Train sample
    pristine_x_train = pristine_x[
        pristine_indices[:int(np.floor(.7 * len(pristine_x)))]]
    pristine_y_train = pristine_y[
        pristine_indices[:int(np.floor(.7 * len(pristine_x)))]]
    # Validation sample --- gets passed into test functions in train file
    pristine_x_valid = pristine_x[pristine_indices[
        int(np.floor(.7 * len(pristine_x))):int(np.floor(.8 *
                                                         len(pristine_x)))]]
    pristine_y_valid = pristine_y[pristine_indices[
        int(np.floor(.7 * len(pristine_x))):int(np.floor(.8 *
                                                         len(pristine_x)))]]
    # Test sample for evaluation file
    pristine_x_test = pristine_x[
        pristine_indices[int(np.floor(.8 * len(pristine_x))):]]
    pristine_y_test = pristine_y[
        pristine_indices[int(np.floor(.8 * len(pristine_x))):]]

    noisy_indices = torch.randperm(len(noisy_x))
    # Train sample
    noisy_x_train = noisy_x[noisy_indices[:int(np.floor(.7 * len(noisy_x)))]]
    noisy_y_train = noisy_y[noisy_indices[:int(np.floor(.7 * len(noisy_x)))]]
    # Validation sample --- gets passed into test functions in train file
    noisy_x_valid = noisy_x[noisy_indices[int(np.floor(.7 * len(noisy_x))
                                              ):int(np.floor(.8 *
                                                             len(noisy_x)))]]
    noisy_y_valid = noisy_y[noisy_indices[int(np.floor(.7 * len(noisy_x))
                                              ):int(np.floor(.8 *
                                                             len(noisy_x)))]]
    # Test sample for evaluation file
    noisy_x_test = noisy_x[noisy_indices[int(np.floor(.8 * len(noisy_x))):]]
    noisy_y_test = noisy_y[noisy_indices[int(np.floor(.8 * len(noisy_x))):]]

    dsets["source"] = TensorDataset(pristine_x_train, pristine_y_train)
    dsets["target"] = TensorDataset(noisy_x_train, noisy_y_train)

    dsets["source_valid"] = TensorDataset(pristine_x_valid, pristine_y_valid)
    dsets["target_valid"] = TensorDataset(noisy_x_valid, noisy_y_valid)

    dsets["source_test"] = TensorDataset(pristine_x_test, pristine_y_test)
    dsets["target_test"] = TensorDataset(noisy_x_test, noisy_y_test)

    dset_loaders["source"] = DataLoader(dsets["source"],
                                        batch_size=128,
                                        shuffle=True,
                                        num_workers=1)
    dset_loaders["target"] = DataLoader(dsets["target"],
                                        batch_size=128,
                                        shuffle=True,
                                        num_workers=1)

    dset_loaders["source_valid"] = DataLoader(dsets["source_valid"],
                                              batch_size=64,
                                              shuffle=True,
                                              num_workers=1)
    dset_loaders["target_valid"] = DataLoader(dsets["target_valid"],
                                              batch_size=64,
                                              shuffle=True,
                                              num_workers=1)

    dset_loaders["source_test"] = DataLoader(dsets["source_test"],
                                             batch_size=64,
                                             shuffle=True,
                                             num_workers=1)
    dset_loaders["target_test"] = DataLoader(dsets["target_test"],
                                             batch_size=64,
                                             shuffle=True,
                                             num_workers=1)

    config['out_file'].write("dataset sizes: source={}, target={}\n".format(
        len(dsets["source"]), len(dsets["target"])))

    # Set number of epochs, and logging intervals
    config["num_iterations"] = len(
        dset_loaders["source"]) * config["epochs"] + 1
    config["test_interval"] = len(dset_loaders["source"])
    config["snapshot_interval"] = len(
        dset_loaders["source"]) * config["epochs"] * .25
    config["log_iter"] = len(dset_loaders["source"])

    # Print the configuration you are using
    config["out_file"].write("config: {}\n".format(config))
    config["out_file"].flush()

    # Set up early stop
    early_stop_engine = EarlyStopping(config["early_stop_patience"])

    # Set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])

    #Loading trained model if we want:
    if config["ckpt_path"] is not None:
        print('load model from {}'.format(config['ckpt_path']))
        ckpt = torch.load(config['ckpt_path'] + '/best_model.pth.tar')
        base_network.load_state_dict(ckpt['base_network'])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    # Collect parameters for the chosen network to be trained
    if "DeepMerge" in args.net:
        parameter_list = [{
            "params": base_network.parameters(),
            "lr_mult": 1,
            'decay_mult': 2
        }]
    elif net_config["params"]["new_cls"]:
        if net_config["params"]["use_bottleneck"]:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                            {"params":base_network.bottleneck.parameters(), "lr_mult":10, 'decay_mult':2}, \
                            {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
        else:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                            {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
    else:
        parameter_list = [{
            "params": base_network.parameters(),
            "lr_mult": 10,
            'decay_mult': 2
        }]

    # Class weights in case we need them, hewe we have balanced sample so weights are 1.0
    class_weight = torch.from_numpy(np.array([1.0] * class_num))
    if use_gpu:
        class_weight = class_weight.cuda()

    parameter_list.append({
        "params": center_criterion.parameters(),
        "lr_mult": 10,
        'decay_mult': 1
    })

    # Set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](parameter_list, \
                    **(optimizer_config["optim_params"]))

    # Set learning rate scheduler
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    scan_lr = []
    scan_loss = []

    ###################
    ###### TRAIN ######
    ###################
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    len_valid_source = len(dset_loaders["source_valid"])
    len_valid_target = len(dset_loaders["target_valid"])

    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0

    for i in range(config["num_iterations"]):

        if i % config["test_interval"] == 0:
            base_network.train(False)
            if config['loss']['ly_type'] == "cosine":
                temp_acc, _ = image_classification_test(dset_loaders, 'source_valid', \
                    base_network, gpu=use_gpu, verbose = False, save_where = None)
                train_acc, _ = image_classification_test(dset_loaders, 'source', \
                    base_network, gpu=use_gpu, verbose = False, save_where = None)
            elif config['loss']['ly_type'] == "euclidean":
                temp_acc, _ = distance_classification_test(dset_loaders, 'source_valid', \
                    base_network, center_criterion.centers.detach(), gpu=use_gpu, verbose = False, save_where = None)
                train_acc, _ = distance_classification_test(dset_loaders, 'source', \
                    base_network, center_criterion.centers.detach(), gpu=use_gpu, verbose = False, save_where = None)
            else:
                raise ValueError("no test method for cls loss: {}".format(
                    config['loss']['ly_type']))

            snapshot_obj = {
                'epoch': i / len(dset_loaders["source"]),
                "base_network": base_network.state_dict(),
                'valid accuracy': temp_acc,
                'train accuracy': train_acc,
            }

            snapshot_obj['center_criterion'] = center_criterion.state_dict()

            if (i + 1) % config["snapshot_interval"] == 0:
                torch.save(
                    snapshot_obj,
                    osp.join(
                        config["output_path"], "epoch_{}_model.pth.tar".format(
                            i / len(dset_loaders["source"]))))

            if temp_acc > best_acc:
                best_acc = temp_acc
                # Save best model
                torch.save(
                    snapshot_obj,
                    osp.join(config["output_path"], "best_model.pth.tar"))

            # Write to log file
            log_str = "epoch: {}, {} validation accuracy: {:.5f}, {} training accuracy: {:.5f}\n".format(
                i / len(dset_loaders["source"]), config['loss']['ly_type'],
                temp_acc, config['loss']['ly_type'], train_acc)
            config["out_file"].write(log_str)
            config["out_file"].flush()
            writer.add_scalar("validation accuracy", temp_acc,
                              i / len(dset_loaders["source"]))
            writer.add_scalar("training accuracy", train_acc,
                              i / len(dset_loaders["source"]))

        ## Train one iteration
        base_network.train(True)

        if i % config["log_iter"] == 0:
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)

        if config["optimizer"]["lr_type"] == "one-cycle":
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)

        if config["optimizer"]["lr_type"] == "linear":
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)

        optim = optimizer.state_dict()
        scan_lr.append(optim['param_groups'][0]['lr'])

        optimizer.zero_grad()

        try:
            inputs_source, labels_source = iter(dset_loaders["source"]).next()
            inputs_target, labels_target = iter(dset_loaders["target"]).next()

        except StopIteration:
            iter(dset_loaders["source"])
            iter(dset_loaders["target"])

        if use_gpu:
            inputs_source, inputs_target, labels_source = \
                Variable(inputs_source).cuda(), Variable(inputs_target).cuda(), \
                Variable(labels_source).cuda()
        else:
            inputs_source, inputs_target, labels_source = Variable(inputs_source), \
                Variable(inputs_target), Variable(labels_source)

        inputs = torch.cat((inputs_source, inputs_target), dim=0)
        source_batch_size = inputs_source.size(0)

        # Distance type. We use cosine.
        if config['loss']['ly_type'] == 'cosine':
            features, logits = base_network(inputs)
            source_logits = logits.narrow(0, 0, source_batch_size)
        elif config['loss']['ly_type'] == 'euclidean':
            features, _ = base_network(inputs)
            logits = -1.0 * loss.distance_to_centroids(
                features, center_criterion.centers.detach())
            source_logits = logits.narrow(0, 0, source_batch_size)

        # Transfer loss - MMD
        transfer_loss = transfer_criterion(features[:source_batch_size],
                                           features[source_batch_size:])

        # Source domain classification task loss
        classifier_loss = class_criterion(source_logits, labels_source.long())

        # Final loss in case we do not want to add Fisher loss and Entropy minimization
        if config["fisher_or_no"] == 'no':
            total_loss = loss_params["trade_off"] * transfer_loss \
            + classifier_loss

            scan_loss.append(total_loss.cpu().float().item())

            total_loss.backward()

            #################
            # Plot embeddings periodically. tSNE plots
            if args.blobs is not None and i / len(
                    dset_loaders["source"]) % 20 == 0:
                visualizePerformance(base_network,
                                     dset_loaders["source"],
                                     dset_loaders["target"],
                                     batch_size=128,
                                     num_of_samples=2000,
                                     imgName='embedding_' +
                                     str(i / len(dset_loaders["source"])),
                                     save_dir=osp.join(config["output_path"],
                                                       "blobs"))
            #################

            optimizer.step()

            if i % config["log_iter"] == 0:

                # In case we want to do a learning rate scan to find the best lr_cycle length:
                if config['lr_scan'] != 'no':
                    if not osp.exists(
                            osp.join(config["output_path"],
                                     "learning_rate_scan")):
                        os.makedirs(
                            osp.join(config["output_path"],
                                     "learning_rate_scan"))

                    plot_learning_rate_scan(
                        scan_lr, scan_loss, i / len(dset_loaders["source"]),
                        osp.join(config["output_path"], "learning_rate_scan"))

                # In case we want to visualize gradients:
                if config['grad_vis'] != 'no':
                    if not osp.exists(
                            osp.join(config["output_path"], "gradients")):
                        os.makedirs(
                            osp.join(config["output_path"], "gradients"))

                    plot_grad_flow(
                        osp.join(config["output_path"], "gradients"),
                        i / len(dset_loaders["source"]),
                        base_network.named_parameters())

                # Logging:
                config['out_file'].write(
                    'epoch {}: train total loss={:0.4f}, train transfer loss={:0.4f}, train classifier loss={:0.4f}\n'
                    .format(
                        i / len(dset_loaders["source"]),
                        total_loss.data.cpu().float().item(),
                        transfer_loss.data.cpu().float().item(),
                        classifier_loss.data.cpu().float().item(),
                    ))

                config['out_file'].flush()

                # Logging for tensorboard
                writer.add_scalar("training total loss",
                                  total_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training classifier loss",
                                  classifier_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training transfer loss",
                                  transfer_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))

                #################
                # Validation step
                #################
                for j in range(0, len(dset_loaders["source_valid"])):
                    base_network.train(False)
                    with torch.no_grad():

                        try:
                            inputs_valid_source, labels_valid_source = iter(
                                dset_loaders["source_valid"]).next()
                            inputs_valid_target, labels_valid_target = iter(
                                dset_loaders["target_valid"]).next()
                        except StopIteration:
                            iter(dset_loaders["source_valid"])
                            iter(dset_loaders["target_valid"])

                        if use_gpu:
                            inputs_valid_source, inputs_valid_target, labels_valid_source = \
                                Variable(inputs_valid_source).cuda(), Variable(inputs_valid_target).cuda(), \
                                Variable(labels_valid_source).cuda()
                        else:
                            inputs_valid_source, inputs_valid_target, labels_valid_source = Variable(inputs_valid_source), \
                                Variable(inputs_valid_target), Variable(labels_valid_source)

                        valid_inputs = torch.cat(
                            (inputs_valid_source, inputs_valid_target), dim=0)
                        valid_source_batch_size = inputs_valid_source.size(0)

                        # Distance type. We use cosine.
                        if config['loss']['ly_type'] == 'cosine':
                            features, logits = base_network(valid_inputs)
                            source_logits = logits.narrow(
                                0, 0, valid_source_batch_size)
                        elif config['loss']['ly_type'] == 'euclidean':
                            features, _ = base_network(valid_inputs)
                            logits = -1.0 * loss.distance_to_centroids(
                                features, center_criterion.centers.detach())
                            source_logits = logits.narrow(
                                0, 0, valid_source_batch_size)

                        # Transfer loss - MMD
                        transfer_loss = transfer_criterion(
                            features[:valid_source_batch_size],
                            features[valid_source_batch_size:])

                        # Source domain classification task loss
                        classifier_loss = class_criterion(
                            source_logits, labels_valid_source.long())

                        # Final loss in case we do not want to add Fisher loss and Entropy minimization
                        total_loss = loss_params["trade_off"] * transfer_loss \
                                    + classifier_loss

                    # Logging:
                    if j % len(dset_loaders["source_valid"]) == 0:
                        config['out_file'].write(
                            'epoch {}: valid total loss={:0.4f}, valid transfer loss={:0.4f}, valid classifier loss={:0.4f}\n'
                            .format(
                                i / len(dset_loaders["source"]),
                                total_loss.data.cpu().float().item(),
                                transfer_loss.data.cpu().float().item(),
                                classifier_loss.data.cpu().float().item(),
                            ))
                        config['out_file'].flush()
                        # Logging for tensorboard:
                        writer.add_scalar("validation total loss",
                                          total_loss.data.cpu().float().item(),
                                          i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation classifier loss",
                            classifier_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation transfer loss",
                            transfer_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))

                        # Early stop in case we see overfitting
                        if early_stop_engine.is_stop_training(
                                classifier_loss.cpu().float().item()):
                            config["out_file"].write(
                                "no improvement after {}, stop training at step {}\n"
                                .format(config["early_stop_patience"],
                                        i / len(dset_loaders["source"])))

                            sys.exit()

        # In case we want to add Fisher loss and Entropy minimizaiton
        else:
            fisher_loss, fisher_intra_loss, fisher_inter_loss, center_grad = center_criterion(
                features.narrow(0, 0, int(inputs.size(0) / 2)),
                labels_source,
                inter_class=loss_params["inter_type"],
                intra_loss_weight=loss_params["intra_loss_coef"],
                inter_loss_weight=loss_params["inter_loss_coef"])

            # Entropy minimization loss
            em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))

            # Final loss is the sum of all losses
            total_loss = loss_params["trade_off"] * transfer_loss \
                 + fisher_loss \
                 + loss_params["em_loss_coef"] * em_loss \
                 + classifier_loss

            scan_loss.append(total_loss.cpu().float().item())

            total_loss.backward()

            #################
            # Plot embeddings periodically.
            if args.blobs is not None and i / len(
                    dset_loaders["source"]) % 20 == 0:
                visualizePerformance(base_network,
                                     dset_loaders["source"],
                                     dset_loaders["target"],
                                     batch_size=128,
                                     num_of_samples=2000,
                                     imgName='embedding_' +
                                     str(i / len(dset_loaders["source"])),
                                     save_dir=osp.join(config["output_path"],
                                                       "blobs"))
            #################

            if center_grad is not None:
                # Clear mmc_loss
                center_criterion.centers.grad.zero_()
                # Manually assign centers gradients other than using autograd
                center_criterion.centers.backward(center_grad)

            optimizer.step()

            if i % config["log_iter"] == 0:

                # In case we want to do a learning rate scan to find the best lr_cycle length:
                if config['lr_scan'] != 'no':
                    if not osp.exists(
                            osp.join(config["output_path"],
                                     "learning_rate_scan")):
                        os.makedirs(
                            osp.join(config["output_path"],
                                     "learning_rate_scan"))

                    plot_learning_rate_scan(
                        scan_lr, scan_loss, i / len(dset_loaders["source"]),
                        osp.join(config["output_path"], "learning_rate_scan"))

                # In case we want to visualize gradients:
                if config['grad_vis'] != 'no':
                    if not osp.exists(
                            osp.join(config["output_path"], "gradients")):
                        os.makedirs(
                            osp.join(config["output_path"], "gradients"))

                    plot_grad_flow(
                        osp.join(config["output_path"], "gradients"),
                        i / len(dset_loaders["source"]),
                        base_network.named_parameters())

                # Logging
                config['out_file'].write(
                    'epoch {}: train total loss={:0.4f}, train transfer loss={:0.4f}, train classifier loss={:0.4f}, '
                    'train entropy min loss={:0.4f}, '
                    'train fisher loss={:0.4f}, train intra-group fisher loss={:0.4f}, train inter-group fisher loss={:0.4f}\n'
                    .format(
                        i / len(dset_loaders["source"]),
                        total_loss.data.cpu(),
                        transfer_loss.data.cpu().float().item(),
                        classifier_loss.data.cpu().float().item(),
                        em_loss.data.cpu().float().item(),
                        fisher_loss.cpu().float().item(),
                        fisher_intra_loss.cpu().float().item(),
                        fisher_inter_loss.cpu().float().item(),
                    ))
                config['out_file'].flush()
                # Logging for tensorboard
                writer.add_scalar("training total loss",
                                  total_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training transfer loss",
                                  transfer_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training classifier loss",
                                  classifier_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training entropy minimization loss",
                                  em_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training total fisher loss",
                                  fisher_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training intra-group fisher",
                                  fisher_intra_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training inter-group fisher",
                                  fisher_inter_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))

                #################
                # Validation step
                #################
                for j in range(0, len(dset_loaders["source_valid"])):
                    base_network.train(False)
                    with torch.no_grad():

                        try:
                            inputs_valid_source, labels_valid_source = iter(
                                dset_loaders["source_valid"]).next()
                            inputs_valid_target, labels_valid_target = iter(
                                dset_loaders["target_valid"]).next()

                        except StopIteration:
                            iter(dset_loaders["source_valid"])
                            iter(dset_loaders["target_valid"])

                        if use_gpu:
                            inputs_valid_source, inputs_valid_target, labels_valid_source = \
                                Variable(inputs_valid_source).cuda(), Variable(inputs_valid_target).cuda(), \
                                Variable(labels_valid_source).cuda()
                        else:
                            inputs_valid_source, inputs_valid_target, labels_valid_source = Variable(inputs_valid_source), \
                                Variable(inputs_valid_target), Variable(labels_valid_source)

                        valid_inputs = torch.cat(
                            (inputs_valid_source, inputs_valid_target), dim=0)
                        valid_source_batch_size = inputs_valid_source.size(0)

                        # Distance type. We use cosine.
                        if config['loss']['ly_type'] == 'cosine':
                            features, logits = base_network(valid_inputs)
                            source_logits = logits.narrow(
                                0, 0, valid_source_batch_size)
                        elif config['loss']['ly_type'] == 'euclidean':
                            features, _ = base_network(valid_inputs)
                            logits = -1.0 * loss.distance_to_centroids(
                                features, center_criterion.centers.detach())
                            source_logits = logits.narrow(
                                0, 0, valid_source_batch_size)

                        # Transfer loss - MMD
                        transfer_loss = transfer_criterion(
                            features[:valid_source_batch_size],
                            features[valid_source_batch_size:])

                        # Source domain classification task loss
                        classifier_loss = class_criterion(
                            source_logits, labels_valid_source.long())
                        # Fisher loss on labeled source domain
                        fisher_loss, fisher_intra_loss, fisher_inter_loss, center_grad = center_criterion(
                            features.narrow(0, 0,
                                            int(valid_inputs.size(0) / 2)),
                            labels_valid_source,
                            inter_class=loss_params["inter_type"],
                            intra_loss_weight=loss_params["intra_loss_coef"],
                            inter_loss_weight=loss_params["inter_loss_coef"])

                        # Entropy minimization loss
                        em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))

                        # Final loss
                        total_loss = loss_params["trade_off"] * transfer_loss \
                                     + fisher_loss \
                                     + loss_params["em_loss_coef"] * em_loss \
                                     + classifier_loss

                    # Logging
                    if j % len(dset_loaders["source_valid"]) == 0:
                        config['out_file'].write(
                            'epoch {}, valid transfer loss={:0.4f}, valid classifier loss={:0.4f}, '
                            'valid entropy min loss={:0.4f}, '
                            'valid fisher loss={:0.4f}, valid intra-group fisher loss={:0.4f}, valid inter-group fisher loss={:0.4f}\n'
                            .format(
                                i / len(dset_loaders["source"]),
                                transfer_loss.data.cpu().float().item(),
                                classifier_loss.data.cpu().float().item(),
                                em_loss.data.cpu().float().item(),
                                fisher_loss.cpu().float().item(),
                                fisher_intra_loss.cpu().float().item(),
                                fisher_inter_loss.cpu().float().item(),
                            ))
                        config['out_file'].flush()
                        # Logging for tensorboard
                        writer.add_scalar("validation total loss",
                                          total_loss.data.cpu().float().item(),
                                          i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation transfer loss",
                            transfer_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation classifier loss",
                            classifier_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation entropy minimization loss",
                            em_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation total fisher loss",
                            fisher_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation intra-group fisher",
                            fisher_intra_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation inter-group fisher",
                            fisher_inter_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))

                        # Early stop in case we see overfitting
                        if early_stop_engine.is_stop_training(
                                classifier_loss.cpu().float().item()):
                            config["out_file"].write(
                                "no improvement after {}, stop training at epoch {}\n"
                                .format(config["early_stop_patience"],
                                        i / len(dset_loaders["source"])))

                            sys.exit()

    return best_acc
Пример #4
0
def train(config):
    """Train a SAN (Selective Adversarial Network) domain-adaptation model.

    Builds source/target/test data pipelines, the base classification
    network, and one adversarial network (with its own gradient-reversal
    layer) per class, then runs adversarial training for
    ``config["num_iterations"]`` steps.  The test set is evaluated every
    ``config["test_interval"]`` iterations and the best-scoring model is
    saved to ``best_model.pth.tar`` in ``config["output_path"]``.

    Args:
        config: dict with keys "prep", "loss", "data", "network",
            "class_num", "dataset", "optimizer", "num_iterations",
            "test_interval", "snapshot_interval", "output_path", "out_file".

    Returns:
        float: the best test-set accuracy observed during training.
    """
    ## set pre-processing transforms for each split
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(
        resize_size=prep_config["resize_size"],
        crop_size=prep_config["crop_size"])
    prep_dict["target"] = prep.image_train(
        resize_size=prep_config["resize_size"],
        crop_size=prep_config["crop_size"])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(
            resize_size=prep_config["resize_size"],
            crop_size=prep_config["crop_size"])
    else:
        prep_dict["test"] = prep.image_test(
            resize_size=prep_config["resize_size"],
            crop_size=prep_config["crop_size"])

    ## set losses
    class_criterion = nn.CrossEntropyLoss()
    transfer_criterion = loss.SAN
    loss_params = config["loss"]

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    dsets["source"] = ImageList(
        open(data_config["source"]["list_path"]).readlines(),
        transform=prep_dict["source"])
    dset_loaders["source"] = util_data.DataLoader(
        dsets["source"],
        batch_size=data_config["source"]["batch_size"],
        shuffle=True, num_workers=4)
    dsets["target"] = ImageList(
        open(data_config["target"]["list_path"]).readlines(),
        transform=prep_dict["target"])
    dset_loaders["target"] = util_data.DataLoader(
        dsets["target"],
        batch_size=data_config["target"]["batch_size"],
        shuffle=True, num_workers=4)

    if prep_config["test_10crop"]:
        # One dataset/loader per crop; the test helper averages over them.
        for crop_idx in range(10):
            dsets["test" + str(crop_idx)] = ImageList(
                open(data_config["test"]["list_path"]).readlines(),
                transform=prep_dict["test"]["val" + str(crop_idx)])
            dset_loaders["test" + str(crop_idx)] = util_data.DataLoader(
                dsets["test" + str(crop_idx)],
                batch_size=data_config["test"]["batch_size"],
                shuffle=False, num_workers=4)
    else:
        dsets["test"] = ImageList(
            open(data_config["test"]["list_path"]).readlines(),
            transform=prep_dict["test"])
        dset_loaders["test"] = util_data.DataLoader(
            dsets["test"],
            batch_size=data_config["test"]["batch_size"],
            shuffle=False, num_workers=4)

    class_num = config["class_num"]

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## collect parameters (pre-trained backbone gets a smaller lr than the
    ## newly initialised layers)
    if net_config["params"]["new_cls"]:
        if net_config["params"]["use_bottleneck"]:
            parameter_list = [
                {"params": base_network.feature_layers.parameters(), "lr": 1},
                {"params": base_network.bottleneck.parameters(), "lr": 10},
                {"params": base_network.fc.parameters(), "lr": 10}]
        else:
            parameter_list = [
                {"params": base_network.feature_layers.parameters(), "lr": 1},
                {"params": base_network.fc.parameters(), "lr": 10}]
    else:
        parameter_list = [{"params": base_network.parameters(), "lr": 1}]

    ## one adversarial network and gradient-reversal layer per class
    ad_net_list = []
    gradient_reverse_layer_list = []
    class_weight = torch.from_numpy(np.array([1.0] * class_num))
    if use_gpu:
        class_weight = class_weight.cuda()

    for _ in range(class_num):
        if config["dataset"] == "caltech":
            ad_net = network.SmallAdversarialNetwork(base_network.output_num())
        elif config["dataset"] == "imagenet":
            ad_net = network.LittleAdversarialNetwork(
                base_network.output_num())
        else:
            ad_net = network.AdversarialNetwork(base_network.output_num())
        gradient_reverse_layer_list.append(network.AdversarialLayer())
        if use_gpu:
            ad_net = ad_net.cuda()
        ad_net_list.append(ad_net)
        parameter_list.append({"params": ad_net.parameters(), "lr": 10})

    ## set optimizer and learning-rate scheduler
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](
        parameter_list, **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    ## train
    len_train_source = len(dset_loaders["source"]) - 1
    len_train_target = len(dset_loaders["target"]) - 1
    best_acc = 0.0
    # Fallback so the final save cannot hit an unbound name if the loop
    # body never runs (num_iterations == 0).
    best_model = nn.Sequential(base_network)
    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == 0:
            base_network.train(False)
            temp_acc = image_classification_test(
                dset_loaders, base_network,
                test_10crop=prep_config["test_10crop"], gpu=use_gpu)
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = temp_model
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            # Trailing newline so successive file-log entries don't run
            # together (print() adds its own).
            config["out_file"].write(log_str + "\n")
            config["out_file"].flush()
            print(log_str)
        if i % config["snapshot_interval"] == 0:
            torch.save(
                nn.Sequential(base_network),
                osp.join(config["output_path"],
                         "iter_{:05d}_model.pth.tar".format(i)))

        ## train one iteration
        base_network.train(True)
        optimizer = lr_scheduler(param_lr, optimizer, i, **schedule_param)
        optimizer.zero_grad()
        # Re-create an iterator whenever its loader is (about to be) exhausted.
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        # BUG FIX: Python 3 iterators have no .next() method; use next().
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        if use_gpu:
            inputs_source = Variable(inputs_source).cuda()
            inputs_target = Variable(inputs_target).cuda()
            labels_source = Variable(labels_source).cuda()
        else:
            inputs_source = Variable(inputs_source)
            inputs_target = Variable(inputs_target)
            labels_source = Variable(labels_source)

        # Forward source and target through the network as one batch.
        inputs = torch.cat((inputs_source, inputs_target), dim=0)
        features, outputs = base_network(inputs)

        ## adversarial transfer loss (SAN) plus entropy minimisation
        # BUG FIX: previously only the loop-leaked last ad_net was switched
        # to train mode; every class-wise adversarial network participates.
        for ad_net in ad_net_list:
            ad_net.train(True)
        softmax_out = nn.Softmax(dim=1)(outputs).detach()
        class_weight = torch.mean(softmax_out.data, 0)
        class_weight = (class_weight / torch.max(class_weight)).view(-1)
        if use_gpu:
            # BUG FIX: .cuda() used to be called unconditionally, crashing
            # on CPU-only machines.
            class_weight = class_weight.cuda()
        # BUG FIX: the inner loop previously shadowed the outer counter `i`.
        for class_idx in range(class_num):
            gradient_reverse_layer_list[class_idx].high = \
                class_weight[class_idx]
        transfer_loss = transfer_criterion(
            [features, softmax_out], ad_net_list,
            gradient_reverse_layer_list, class_weight, use_gpu) \
            + loss_params["entropy_trade_off"] * \
            loss.EntropyLoss(nn.Softmax(dim=1)(outputs))
        # Source classification loss over the first half of the joint batch.
        # BUG FIX: narrow() needs an int length; `/` is float division in
        # Python 3, so use floor division.
        classifier_loss = class_criterion(
            outputs.narrow(0, 0, inputs.size(0) // 2), labels_source)

        total_loss = loss_params["trade_off"] * transfer_loss + classifier_loss
        total_loss.backward()
        optimizer.step()
    torch.save(best_model,
               osp.join(config["output_path"], "best_model.pth.tar"))
    return best_acc
def train(config):
    """Source-only training loop (no target-domain / adaptation data).

    Splits the module-level ``pristine_x`` / ``pristine_y`` tensors into
    70% train / 10% validation / 20% held-out test, then trains the network
    described by ``config["network"]`` with cross-entropy plus an
    entropy-minimization regularizer.  Validation accuracy is measured every
    ``config["test_interval"]`` iterations; the best snapshot is written to
    ``config["output_path"]/best_model.pth.tar``.

    NOTE(review): unlike the sibling train() variants, this one does not
    derive ``num_iterations`` / ``test_interval`` / ``snapshot_interval`` /
    ``log_iter`` from the loader length — it assumes the caller already put
    them in ``config``.  Confirm with the call site.

    Args:
        config: dict with network/optimizer/loss settings, an open
            ``out_file`` handle, ``output_path`` and the iteration counters
            listed above.

    Returns:
        float: best validation accuracy observed during training.

    Raises:
        ValueError: if ``config['loss']['ly_type']`` is not ``"cosine"``
            (the only test method wired up in this variant).
    """
    ## set up summary writer and early stopping
    writer = SummaryWriter(config['output_path'])
    early_stop_engine = EarlyStopping(config["early_stop_patience"])

    class_criterion = nn.CrossEntropyLoss()
    loss_params = config["loss"]

    ## prepare data: sample without replacement, 70/10/20 split
    n_samples = len(pristine_x)
    pristine_indices = torch.randperm(n_samples)
    train_end = int(np.floor(.7 * n_samples))
    valid_end = int(np.floor(.8 * n_samples))
    # train split
    pristine_x_train = pristine_x[pristine_indices[:train_end]]
    pristine_y_train = pristine_y[pristine_indices[:train_end]]
    # validation split --- consumed by the accuracy tests below
    pristine_x_valid = pristine_x[pristine_indices[train_end:valid_end]]
    pristine_y_valid = pristine_y[pristine_indices[train_end:valid_end]]
    # held-out test split for the separate evaluation file
    pristine_x_test = pristine_x[pristine_indices[valid_end:]]
    pristine_y_test = pristine_y[pristine_indices[valid_end:]]

    dsets = {
        "source": TensorDataset(pristine_x_train, pristine_y_train),
        "source_valid": TensorDataset(pristine_x_valid, pristine_y_valid),
        "source_test": TensorDataset(pristine_x_test, pristine_y_test),
    }
    dset_loaders = {
        "source": DataLoader(dsets["source"], batch_size=36,
                             shuffle=True, num_workers=1),
        "source_valid": DataLoader(dsets["source_valid"], batch_size=4,
                                   shuffle=True, num_workers=1),
        "source_test": DataLoader(dsets["source_test"], batch_size=4,
                                  shuffle=True, num_workers=1),
    }

    config['out_file'].write("dataset sizes: source={}\n".format(
        len(dsets["source"])))

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## collect parameters: pretrained feature layers keep a 1x learning-rate
    ## multiplier, freshly initialised heads get 10x
    if net_config["params"]["new_cls"]:
        if net_config["params"]["use_bottleneck"]:
            parameter_list = [
                {"params": base_network.feature_layers.parameters(),
                 "lr_mult": 1, 'decay_mult': 2},
                {"params": base_network.bottleneck.parameters(),
                 "lr_mult": 10, 'decay_mult': 2},
                {"params": base_network.fc.parameters(),
                 "lr_mult": 10, 'decay_mult': 2}]
        else:
            parameter_list = [
                {"params": base_network.feature_layers.parameters(),
                 "lr_mult": 1, 'decay_mult': 2},
                {"params": base_network.fc.parameters(),
                 "lr_mult": 10, 'decay_mult': 2}]
    else:
        parameter_list = [{"params": base_network.parameters(),
                           "lr_mult": 1, 'decay_mult': 2}]

    ## set optimizer and per-group learning-rate schedule
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](
        parameter_list, **(optimizer_config["optim_params"]))
    param_lr = [param_group["lr"] for param_group in optimizer.param_groups]
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    ## train
    # "-1" drops the final (possibly short) batch from the reset cadence;
    # max(..., 1) guards the modulo against a single-batch loader
    len_train_source = max(len(dset_loaders["source"]) - 1, 1)
    len_valid_source = max(len(dset_loaders["source_valid"]) - 1, 1)
    # snapshot_interval may arrive as a float (e.g. 0.25 * total iterations);
    # coerce for the integer modulo below
    snapshot_interval = max(int(config["snapshot_interval"]), 1)

    best_acc = 0.0
    iter_source = iter(dset_loaders["source"])
    # separate validation iterator: the original reused `iter_source` here,
    # which silently fed validation batches into the next training step
    iter_valid = iter(dset_loaders["source_valid"])

    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == 0:
            base_network.train(False)
            if config['loss']['ly_type'] == "cosine":
                temp_acc, _ = image_classification_test(
                    dset_loaders, 'source_valid', base_network, gpu=use_gpu)
                train_acc, _ = image_classification_test(
                    dset_loaders, 'source', base_network, gpu=use_gpu)
            else:
                raise ValueError("no test method for cls loss: {}".format(
                    config['loss']['ly_type']))

            snapshot_obj = {'step': i,
                            "base_network": base_network.state_dict(),
                            'valid accuracy': temp_acc,
                            'train accuracy': train_acc,
                            }
            if temp_acc > best_acc:
                best_acc = temp_acc
                # save best model
                torch.save(snapshot_obj,
                           osp.join(config["output_path"],
                                    "best_model.pth.tar"))
            log_str = ("iter: {:05d}, {} validation accuracy: {:.5f}, "
                       "{} training accuracy: {:.5f}\n").format(
                           i, config['loss']['ly_type'], temp_acc,
                           config['loss']['ly_type'], train_acc)
            config["out_file"].write(log_str)
            config["out_file"].flush()
            writer.add_scalar("validation accuracy", temp_acc, i)
            writer.add_scalar("training accuracy", train_acc, i)

            if early_stop_engine.is_stop_training(temp_acc):
                config["out_file"].write(
                    "no improvement after {}, stop training at step {}\n".format(
                        config["early_stop_patience"], i))
                break

        # `snapshot_obj` always exists here: the test branch runs at i == 0
        if (i + 1) % snapshot_interval == 0:
            torch.save(snapshot_obj,
                       osp.join(config["output_path"],
                                "iter_{:05d}_model.pth.tar".format(i)))

        ## train one iteration
        base_network.train(True)
        optimizer = lr_scheduler(param_lr, optimizer, i, **schedule_param)
        optimizer.zero_grad()

        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        try:
            # built-in next(): the Python 2 `.next()` method no longer exists
            inputs_source, labels_source = next(iter_source)
        except StopIteration:
            # restart AND fetch -- the original restarted without fetching,
            # silently reusing the previous batch
            iter_source = iter(dset_loaders["source"])
            inputs_source, labels_source = next(iter_source)

        if use_gpu:
            inputs_source = Variable(inputs_source).cuda()
            labels_source = Variable(labels_source).cuda()
        else:
            inputs_source = Variable(inputs_source)
            labels_source = Variable(labels_source)

        inputs = inputs_source
        source_batch_size = inputs_source.size(0)

        _, logits = base_network(inputs)
        source_logits = logits.narrow(0, 0, source_batch_size)

        # source domain classification task loss
        classifier_loss = class_criterion(source_logits, labels_source.long())
        # entropy minimization loss over the full batch
        em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))

        total_loss = loss_params["em_loss_coef"] * em_loss + classifier_loss
        total_loss.backward()
        optimizer.step()

        if i % config["log_iter"] == 0:
            config['out_file'].write(
                'iter {}: train total loss={:0.4f}, train classifier loss={:0.4f}, '
                'train entropy min loss={:0.4f}\n'.format(
                    i, total_loss.data.cpu().float().item(),
                    classifier_loss.data.cpu().float().item(),
                    em_loss.data.cpu().float().item()))
            config['out_file'].flush()
            writer.add_scalar("training total loss",
                              total_loss.data.cpu().float().item(), i)
            writer.add_scalar("training classifier loss",
                              classifier_loss.data.cpu().float().item(), i)
            writer.add_scalar("training entropy minimization loss",
                              em_loss.data.cpu().float().item(), i)

        ## validation pass (no gradients, no backprop)
        base_network.eval()
        with torch.no_grad():
            if i % len_valid_source == 0:
                iter_valid = iter(dset_loaders["source_valid"])
            try:
                inputs_source, labels_source = next(iter_valid)
            except StopIteration:
                iter_valid = iter(dset_loaders["source_valid"])
                inputs_source, labels_source = next(iter_valid)

            if use_gpu:
                inputs_source = Variable(inputs_source).cuda()
                labels_source = Variable(labels_source).cuda()
            else:
                inputs_source = Variable(inputs_source)
                labels_source = Variable(labels_source)

            inputs = inputs_source
            source_batch_size = inputs_source.size(0)

            _, logits = base_network(inputs)
            source_logits = logits.narrow(0, 0, source_batch_size)

            classifier_loss = class_criterion(source_logits,
                                              labels_source.long())
            em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))
            total_loss = loss_params["em_loss_coef"] * em_loss \
                + classifier_loss

        if i % config["log_iter"] == 0:
            config['out_file'].write(
                'iter {}: validation total loss={:0.4f}, validation classifier loss={:0.4f}, '
                'validation entropy min loss={:0.4f}\n'.format(
                    i, total_loss.data.cpu().float().item(),
                    classifier_loss.data.cpu().float().item(),
                    em_loss.data.cpu().float().item()))
            config['out_file'].flush()
            writer.add_scalar("validation total loss",
                              total_loss.data.cpu().float().item(), i)
            writer.add_scalar("validation classifier loss",
                              classifier_loss.data.cpu().float().item(), i)
            writer.add_scalar("validation entropy minimization loss",
                              em_loss.data.cpu().float().item(), i)

    writer.close()  # flush pending TensorBoard events
    return best_acc
# Example #6
def train(config):
    ## set up summary writer
    writer = SummaryWriter(config['output_path'])
    class_num = config["network"]["params"]["class_num"]
    loss_params = config["loss"]

    class_criterion = nn.CrossEntropyLoss()
    transfer_criterion = loss.PADA
    center_criterion = loss_params["loss_type"](
        num_classes=class_num,
        feat_dim=config["network"]["params"]["bottleneck_dim"])

    ## prepare data
    dsets = {}
    dset_loaders = {}

    #sampling WOR, i guess we leave the 10 in the middle to validate?
    pristine_indices = torch.randperm(len(pristine_x))
    #train
    pristine_x_train = pristine_x[
        pristine_indices[:int(np.floor(.7 * len(pristine_x)))]]
    pristine_y_train = pristine_y[
        pristine_indices[:int(np.floor(.7 * len(pristine_x)))]]
    #validate --- gets passed into test functions in train file
    pristine_x_valid = pristine_x[pristine_indices[
        int(np.floor(.7 * len(pristine_x))):int(np.floor(.8 *
                                                         len(pristine_x)))]]
    pristine_y_valid = pristine_y[pristine_indices[
        int(np.floor(.7 * len(pristine_x))):int(np.floor(.8 *
                                                         len(pristine_x)))]]
    #test for evaluation file
    pristine_x_test = pristine_x[
        pristine_indices[int(np.floor(.8 * len(pristine_x))):]]
    pristine_y_test = pristine_y[
        pristine_indices[int(np.floor(.8 * len(pristine_x))):]]

    noisy_indices = torch.randperm(len(noisy_x))
    #train
    noisy_x_train = noisy_x[noisy_indices[:int(np.floor(.7 * len(noisy_x)))]]
    noisy_y_train = noisy_y[noisy_indices[:int(np.floor(.7 * len(noisy_x)))]]
    #validate --- gets passed into test functions in train file
    noisy_x_valid = noisy_x[noisy_indices[int(np.floor(.7 * len(noisy_x))
                                              ):int(np.floor(.8 *
                                                             len(noisy_x)))]]
    noisy_y_valid = noisy_y[noisy_indices[int(np.floor(.7 * len(noisy_x))
                                              ):int(np.floor(.8 *
                                                             len(noisy_x)))]]
    #test for evaluation file
    noisy_x_test = noisy_x[noisy_indices[int(np.floor(.8 * len(noisy_x))):]]
    noisy_y_test = noisy_y[noisy_indices[int(np.floor(.8 * len(noisy_x))):]]

    dsets["source"] = TensorDataset(pristine_x_train, pristine_y_train)
    dsets["target"] = TensorDataset(noisy_x_train, noisy_y_train)

    dsets["source_valid"] = TensorDataset(pristine_x_valid, pristine_y_valid)
    dsets["target_valid"] = TensorDataset(noisy_x_valid, noisy_y_valid)

    dsets["source_test"] = TensorDataset(pristine_x_test, pristine_y_test)
    dsets["target_test"] = TensorDataset(noisy_x_test, noisy_y_test)

    #put your dataloaders here
    #i stole batch size numbers from below
    dset_loaders["source"] = DataLoader(dsets["source"],
                                        batch_size=128,
                                        shuffle=True,
                                        num_workers=1)
    dset_loaders["target"] = DataLoader(dsets["target"],
                                        batch_size=128,
                                        shuffle=True,
                                        num_workers=1)

    #guessing batch size based on what was done for testing in the original file
    dset_loaders["source_valid"] = DataLoader(dsets["source_valid"],
                                              batch_size=64,
                                              shuffle=True,
                                              num_workers=1)
    dset_loaders["target_valid"] = DataLoader(dsets["target_valid"],
                                              batch_size=64,
                                              shuffle=True,
                                              num_workers=1)

    dset_loaders["source_test"] = DataLoader(dsets["source_test"],
                                             batch_size=64,
                                             shuffle=True,
                                             num_workers=1)
    dset_loaders["target_test"] = DataLoader(dsets["target_test"],
                                             batch_size=64,
                                             shuffle=True,
                                             num_workers=1)

    config['out_file'].write("dataset sizes: source={}, target={}\n".format(
        len(dsets["source"]), len(dsets["target"])))

    config["num_iterations"] = len(
        dset_loaders["source"]) * config["epochs"] + 1
    config["test_interval"] = len(dset_loaders["source"])
    config["snapshot_interval"] = len(
        dset_loaders["source"]) * config["epochs"] * .25
    config["log_iter"] = len(dset_loaders["source"])

    #print the configuration you are using
    config["out_file"].write("config: {}\n".format(config))
    config["out_file"].flush()

    # set up early stop
    early_stop_engine = EarlyStopping(config["early_stop_patience"])

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## add additional network for some methods
    ad_net = network.AdversarialNetwork(base_network.output_num())
    gradient_reverse_layer = network.AdversarialLayer(
        high_value=config["high"])  #,
    #max_iter_value=config["num_iterations"])
    if use_gpu:
        ad_net = ad_net.cuda()

        ## collect parameters
    if "DeepMerge" in args.net:
        parameter_list = [{
            "params": base_network.parameters(),
            "lr_mult": 1,
            'decay_mult': 2
        }]
        parameter_list.append({
            "params": ad_net.parameters(),
            "lr_mult": .1,
            'decay_mult': 2
        })
        parameter_list.append({
            "params": center_criterion.parameters(),
            "lr_mult": 10,
            'decay_mult': 1
        })
    elif "ResNet18" in args.net:
        parameter_list = [{
            "params": base_network.parameters(),
            "lr_mult": 1,
            'decay_mult': 2
        }]
        parameter_list.append({
            "params": ad_net.parameters(),
            "lr_mult": .1,
            'decay_mult': 2
        })
        parameter_list.append({
            "params": center_criterion.parameters(),
            "lr_mult": 10,
            'decay_mult': 1
        })

    if net_config["params"]["new_cls"]:
        if net_config["params"]["use_bottleneck"]:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                            {"params":base_network.bottleneck.parameters(), "lr_mult":10, 'decay_mult':2}, \
                            {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
            parameter_list.append({
                "params": ad_net.parameters(),
                "lr_mult": config["ad_net_mult_lr"],
                'decay_mult': 2
            })
            parameter_list.append({
                "params": center_criterion.parameters(),
                "lr_mult": 10,
                'decay_mult': 1
            })
        else:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                            {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
            parameter_list.append({
                "params": ad_net.parameters(),
                "lr_mult": config["ad_net_mult_lr"],
                'decay_mult': 2
            })
            parameter_list.append({
                "params": center_criterion.parameters(),
                "lr_mult": 10,
                'decay_mult': 1
            })
    else:
        parameter_list = [{
            "params": base_network.parameters(),
            "lr_mult": 1,
            'decay_mult': 2
        }]
        parameter_list.append({
            "params": ad_net.parameters(),
            "lr_mult": config["ad_net_mult_lr"],
            'decay_mult': 2
        })
        parameter_list.append({
            "params": center_criterion.parameters(),
            "lr_mult": 10,
            'decay_mult': 1
        })
    #Should I put lr_mult here as 1 for DeepMerge too? Probably!

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    scan_lr = []
    scan_loss = []

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    len_valid_source = len(dset_loaders["source_valid"])
    len_valid_target = len(dset_loaders["target_valid"])

    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0

    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == 0:
            base_network.train(False)
            if config['loss']['ly_type'] == "cosine":
                temp_acc, _ = image_classification_test(dset_loaders, 'source_valid', \
                    base_network, gpu=use_gpu, verbose = False, save_where = None)
                train_acc, _ = image_classification_test(dset_loaders, 'source', \
                    base_network, gpu=use_gpu, verbose = False, save_where = None)
            elif config['loss']['ly_type'] == "euclidean":
                temp_acc, _ = distance_classification_test(dset_loaders, 'source_valid', \
                    base_network, center_criterion.centers.detach(), gpu=use_gpu, verbose = False, save_where = None)
                train_acc, _ = distance_classification_test(
                    dset_loaders,
                    'source',
                    base_network,
                    center_criterion.centers.detach(),
                    gpu=use_gpu,
                    verbose=False,
                    save_where=None)
            else:
                raise ValueError("no test method for cls loss: {}".format(
                    config['loss']['ly_type']))

            snapshot_obj = {
                'epoch': i / len(dset_loaders["source"]),
                "base_network": base_network.state_dict(),
                'valid accuracy': temp_acc,
                'train accuracy': train_acc,
            }

            if (i + 1) % config["snapshot_interval"] == 0:
                torch.save(
                    snapshot_obj,
                    osp.join(
                        config["output_path"], "epoch_{}_model.pth.tar".format(
                            i / len(dset_loaders["source"]))))

            if config["loss"]["loss_name"] != "laplacian" and config["loss"][
                    "ly_type"] == "euclidean":
                snapshot_obj['center_criterion'] = center_criterion.state_dict(
                )

            if temp_acc > best_acc:
                best_acc = temp_acc

                # save best model
                torch.save(
                    snapshot_obj,
                    osp.join(config["output_path"], "best_model.pth.tar"))

            log_str = "epoch: {}, {} validation accuracy: {:.5f}, {} training accuracy: {:.5f}\n".format(
                i / len(dset_loaders["source"]), config['loss']['ly_type'],
                temp_acc, config['loss']['ly_type'], train_acc)
            config["out_file"].write(log_str)
            config["out_file"].flush()
            writer.add_scalar("validation accuracy", temp_acc,
                              i / len(dset_loaders["source"]))
            writer.add_scalar("training accuracy", train_acc,
                              i / len(dset_loaders["source"]))

        ## train one iter
        base_network.train(True)

        if i % config["log_iter"] == 0:
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)

        if config["optimizer"]["lr_type"] == "one-cycle":
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)

        if config["optimizer"]["lr_type"] == "linear":
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"], config["frozen lr"],
                                     config["cycle_length"], **schedule_param)

        optim = optimizer.state_dict()
        scan_lr.append(optim['param_groups'][0]['lr'])

        optimizer.zero_grad()

        try:
            inputs_source, labels_source = iter(dset_loaders["source"]).next()
            inputs_target, labels_target = iter(dset_loaders["target"]).next()
        except StopIteration:
            iter(dset_loaders["source"])
            iter(dset_loaders["target"])

        if use_gpu:
            inputs_source, inputs_target, labels_source = \
                Variable(inputs_source).cuda(), Variable(inputs_target).cuda(), \
                Variable(labels_source).cuda()
        else:
            inputs_source, inputs_target, labels_source = Variable(inputs_source), \
                Variable(inputs_target), Variable(labels_source)

        inputs = torch.cat((inputs_source, inputs_target), dim=0)
        source_batch_size = inputs_source.size(0)

        if config['loss']['ly_type'] == 'cosine':
            features, logits = base_network(inputs)
            source_logits = logits.narrow(0, 0, source_batch_size)
        elif config['loss']['ly_type'] == 'euclidean':
            features, _ = base_network(inputs)
            logits = -1.0 * loss.distance_to_centroids(
                features, center_criterion.centers.detach())
            source_logits = logits.narrow(0, 0, source_batch_size)

        ad_net.train(True)
        # Uniform per-sample weights for the adversarial (PADA) transfer loss.
        weight_ad = torch.ones(inputs.size(0))
        transfer_loss = transfer_criterion(features, ad_net, gradient_reverse_layer, \
                                            weight_ad, use_gpu)
        # Domain-classifier diagnostics on detached features (no grad through
        # the base network for this extra forward pass).
        ad_out, _ = ad_net(features.detach())
        ad_acc, source_acc_ad, target_acc_ad = domain_cls_accuracy(ad_out)

        # source domain classification task loss
        classifier_loss = class_criterion(source_logits, labels_source.long())
        # fisher loss on labeled source domain

        if config["fisher_or_no"] == 'no':
            # No discriminant term: objective = weighted adversarial transfer
            # loss + source classification loss.
            total_loss = loss_params["trade_off"] * transfer_loss \
            + classifier_loss

            scan_loss.append(total_loss.cpu().float().item())

            total_loss.backward()

            ######################################
            # Plot embeddings periodically.
            # NOTE(review): `i / len(...)` is float division, so `% 50 == 0`
            # only fires when the quotient is an exact multiple of 50;
            # integer division (`//`) was probably intended — confirm.
            if args.blobs is not None and i / len(
                    dset_loaders["source"]) % 50 == 0:
                visualizePerformance(base_network,
                                     dset_loaders["source"],
                                     dset_loaders["target"],
                                     batch_size=128,
                                     domain_classifier=ad_net,
                                     num_of_samples=100,
                                     imgName='embedding_' +
                                     str(i / len(dset_loaders["source"])),
                                     save_dir=osp.join(config["output_path"],
                                                       "blobs"))
            ##########################################

            # if center_grad is not None:
            #     # clear mmc_loss
            #     center_criterion.centers.grad.zero_()
            #     # Manually assign centers gradients other than using autograd
            #     center_criterion.centers.backward(center_grad)

            optimizer.step()

            # Per-epoch logging: log_iter is set to len(source loader) above,
            # so this fires once per pass over the source data.
            if i % config["log_iter"] == 0:

                # Optional learning-rate scan plot.
                if config['lr_scan'] != 'no':
                    if not osp.exists(
                            osp.join(config["output_path"],
                                     "learning_rate_scan")):
                        os.makedirs(
                            osp.join(config["output_path"],
                                     "learning_rate_scan"))

                    plot_learning_rate_scan(
                        scan_lr, scan_loss, i / len(dset_loaders["source"]),
                        osp.join(config["output_path"], "learning_rate_scan"))

                # Optional gradient-flow visualization.
                if config['grad_vis'] != 'no':
                    if not osp.exists(
                            osp.join(config["output_path"], "gradients")):
                        os.makedirs(
                            osp.join(config["output_path"], "gradients"))

                    plot_grad_flow(
                        osp.join(config["output_path"], "gradients"),
                        i / len(dset_loaders["source"]),
                        base_network.named_parameters())

                # Text log of training metrics for this epoch.
                config['out_file'].write(
                    'epoch {}: train total loss={:0.4f}, train transfer loss={:0.4f}, train classifier loss={:0.4f},'
                    'train source+target domain accuracy={:0.4f}, train source domain accuracy={:0.4f}, train target domain accuracy={:0.4f}\n'
                    .format(
                        i / len(dset_loaders["source"]),
                        total_loss.data.cpu().float().item(),
                        transfer_loss.data.cpu().float().item(),
                        classifier_loss.data.cpu().float().item(),
                        ad_acc,
                        source_acc_ad,
                        target_acc_ad,
                    ))
                config['out_file'].flush()
                # Mirror the same metrics to TensorBoard.
                writer.add_scalar("training total loss",
                                  total_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training classifier loss",
                                  classifier_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training transfer loss",
                                  transfer_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training source+target domain accuracy",
                                  ad_acc, i / len(dset_loaders["source"]))
                writer.add_scalar("training source domain accuracy",
                                  source_acc_ad,
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training target domain accuracy",
                                  target_acc_ad,
                                  i / len(dset_loaders["source"]))

                #attempted validation step
                for j in range(0, len(dset_loaders["source_valid"])):
                    base_network.train(False)
                    with torch.no_grad():

                        # NOTE(review): `iter(loader)` builds a FRESH iterator
                        # each pass, so `.next()` always returns the first
                        # batch (every j sees the same data), and the
                        # iterators recreated in the StopIteration handler are
                        # discarded. `.next()` is also Python-2 style; Python 3
                        # needs `next(iterator)` — confirm against the torch
                        # version in use.
                        try:
                            inputs_valid_source, labels_valid_source = iter(
                                dset_loaders["source_valid"]).next()
                            inputs_valid_target, labels_valid_target = iter(
                                dset_loaders["target_valid"]).next()
                        except StopIteration:
                            iter(dset_loaders["source_valid"])
                            iter(dset_loaders["target_valid"])

                        if use_gpu:
                            inputs_valid_source, inputs_valid_target, labels_valid_source = \
                                Variable(inputs_valid_source).cuda(), Variable(inputs_valid_target).cuda(), \
                                Variable(labels_valid_source).cuda()
                        else:
                            inputs_valid_source, inputs_valid_target, labels_valid_source = Variable(inputs_valid_source), \
                                Variable(inputs_valid_target), Variable(labels_valid_source)

                        # Forward source+target together; the first
                        # `valid_source_batch_size` rows are the source half.
                        valid_inputs = torch.cat(
                            (inputs_valid_source, inputs_valid_target), dim=0)
                        valid_source_batch_size = inputs_valid_source.size(0)

                        if config['loss']['ly_type'] == 'cosine':
                            features, logits = base_network(valid_inputs)
                            source_logits = logits.narrow(
                                0, 0, valid_source_batch_size)
                        elif config['loss']['ly_type'] == 'euclidean':
                            # Euclidean variant: logits are negated distances
                            # to the (detached) class centers.
                            features, _ = base_network(valid_inputs)
                            logits = -1.0 * loss.distance_to_centroids(
                                features, center_criterion.centers.detach())
                            source_logits = logits.narrow(
                                0, 0, valid_source_batch_size)

                        # Validation-time domain metrics, mirroring training.
                        ad_net.train(False)
                        weight_ad = torch.ones(valid_inputs.size(0))
                        transfer_loss = transfer_criterion(features, ad_net, gradient_reverse_layer, \
                                                           weight_ad, use_gpu)
                        ad_out, _ = ad_net(features.detach())
                        ad_acc, source_acc_ad, target_acc_ad = domain_cls_accuracy(
                            ad_out)

                        # source domain classification task loss
                        classifier_loss = class_criterion(
                            source_logits, labels_valid_source.long())

                        #if config["fisher_or_no"] == 'no':
                        total_loss = loss_params["trade_off"] * transfer_loss \
                                    + classifier_loss

                    # NOTE(review): j < len(loader), so this is only True at
                    # j == 0 — validation metrics are logged for one batch.
                    if j % len(dset_loaders["source_valid"]) == 0:
                        config['out_file'].write(
                            'epoch {}: valid total loss={:0.4f}, valid transfer loss={:0.4f}, valid classifier loss={:0.4f},'
                            'valid source+target domain accuracy={:0.4f}, valid source domain accuracy={:0.4f}, valid target domain accuracy={:0.4f}\n'
                            .format(
                                i / len(dset_loaders["source"]),
                                total_loss.data.cpu().float().item(),
                                transfer_loss.data.cpu().float().item(),
                                classifier_loss.data.cpu().float().item(),
                                ad_acc,
                                source_acc_ad,
                                target_acc_ad,
                            ))
                        config['out_file'].flush()
                        writer.add_scalar("validation total loss",
                                          total_loss.data.cpu().float().item(),
                                          i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation classifier loss",
                            classifier_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation transfer loss",
                            transfer_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation source+target domain accuracy", ad_acc,
                            i / len(dset_loaders["source"]))
                        writer.add_scalar("validation source domain accuracy",
                                          source_acc_ad,
                                          i / len(dset_loaders["source"]))
                        writer.add_scalar("validation target domain accuracy",
                                          target_acc_ad,
                                          i / len(dset_loaders["source"]))

                        # Early stopping on validation classifier loss;
                        # exits the whole process via sys.exit().
                        if early_stop_engine.is_stop_training(
                                classifier_loss.cpu().float().item()):
                            config["out_file"].write(
                                "no improvement after {}, stop training at step {}\n"
                                .format(config["early_stop_patience"],
                                        i / len(dset_loaders["source"])))

                            sys.exit()

        else:
            # Discriminant (fisher) variant: intra/inter-class fisher terms on
            # the source half of the concatenated batch, plus entropy
            # minimization on all logits.
            fisher_loss, fisher_intra_loss, fisher_inter_loss, center_grad = center_criterion(
                features.narrow(0, 0, int(inputs.size(0) / 2)),
                labels_source,
                inter_class=config["loss"]["inter_type"],
                intra_loss_weight=loss_params["intra_loss_coef"],
                inter_loss_weight=loss_params["inter_loss_coef"])
            # entropy minimization loss
            em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))

            # final loss
            total_loss = loss_params["trade_off"] * transfer_loss \
                         + fisher_loss \
                         + loss_params["em_loss_coef"] * em_loss \
                         + classifier_loss

            scan_loss.append(total_loss.cpu().float().item())

            total_loss.backward()

            ######################################
            # Plot embeddings periodically.
            # NOTE(review): same float-division `% 50` issue as in the
            # no-fisher branch — `//` was probably intended; confirm.
            if args.blobs is not None and i / len(
                    dset_loaders["source"]) % 50 == 0:
                visualizePerformance(base_network,
                                     dset_loaders["source"],
                                     dset_loaders["target"],
                                     batch_size=128,
                                     num_of_samples=50,
                                     imgName='embedding_' +
                                     str(i / len(dset_loaders["source"])),
                                     save_dir=osp.join(config["output_path"],
                                                       "blobs"))
            ##########################################

            # The center criterion's gradients are assigned manually rather
            # than produced by autograd.
            if center_grad is not None:
                # clear mmc_loss
                center_criterion.centers.grad.zero_()
                # Manually assign centers gradients other than using autograd
                center_criterion.centers.backward(center_grad)

            optimizer.step()

            # Per-epoch logging, mirroring the no-fisher branch but with the
            # extra entropy/fisher metrics.
            if i % config["log_iter"] == 0:

                if config['lr_scan'] != 'no':
                    if not osp.exists(
                            osp.join(config["output_path"],
                                     "learning_rate_scan")):
                        os.makedirs(
                            osp.join(config["output_path"],
                                     "learning_rate_scan"))

                    plot_learning_rate_scan(
                        scan_lr, scan_loss, i / len(dset_loaders["source"]),
                        osp.join(config["output_path"], "learning_rate_scan"))

                if config['grad_vis'] != 'no':
                    if not osp.exists(
                            osp.join(config["output_path"], "gradients")):
                        os.makedirs(
                            osp.join(config["output_path"], "gradients"))

                    plot_grad_flow(
                        osp.join(config["output_path"], "gradients"),
                        i / len(dset_loaders["source"]),
                        base_network.named_parameters())

                config['out_file'].write(
                    'epoch {}: train total loss={:0.4f}, train transfer loss={:0.4f}, train classifier loss={:0.4f}, '
                    'train entropy min loss={:0.4f}, '
                    'train fisher loss={:0.4f}, train intra-group fisher loss={:0.4f}, train inter-group fisher loss={:0.4f}, '
                    'train source+target domain accuracy={:0.4f}, train source domain accuracy={:0.4f}, train target domain accuracy={:0.4f}\n'
                    .format(
                        i / len(dset_loaders["source"]),
                        total_loss.data.cpu().float().item(),
                        transfer_loss.data.cpu().float().item(),
                        classifier_loss.data.cpu().float().item(),
                        em_loss.data.cpu().float().item(),
                        fisher_loss.cpu().float().item(),
                        fisher_intra_loss.cpu().float().item(),
                        fisher_inter_loss.cpu().float().item(),
                        ad_acc,
                        source_acc_ad,
                        target_acc_ad,
                    ))

                config['out_file'].flush()
                writer.add_scalar("training total loss",
                                  total_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training classifier loss",
                                  classifier_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training transfer loss",
                                  transfer_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training entropy minimization loss",
                                  em_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training total fisher loss",
                                  fisher_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training intra-group fisher",
                                  fisher_intra_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training inter-group fisher",
                                  fisher_inter_loss.data.cpu().float().item(),
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training source+target domain accuracy",
                                  ad_acc, i / len(dset_loaders["source"]))
                writer.add_scalar("training source domain accuracy",
                                  source_acc_ad,
                                  i / len(dset_loaders["source"]))
                writer.add_scalar("training target domain accuracy",
                                  target_acc_ad,
                                  i / len(dset_loaders["source"]))

                #attempted validation step
                for j in range(0, len(dset_loaders["source_valid"])):
                    base_network.train(False)
                    with torch.no_grad():

                        # NOTE(review): same validation-batching defect as the
                        # no-fisher branch — fresh `iter(loader).next()` each
                        # pass always yields the first batch, the recreated
                        # iterators in the handler are discarded, and
                        # `.next()` is Python-2 style; confirm.
                        try:
                            inputs_valid_source, labels_valid_source = iter(
                                dset_loaders["source_valid"]).next()
                            inputs_valid_target, labels_valid_target = iter(
                                dset_loaders["target_valid"]).next()
                        except StopIteration:
                            iter(dset_loaders["source_valid"])
                            iter(dset_loaders["target_valid"])

                        if use_gpu:
                            inputs_valid_source, inputs_valid_target, labels_valid_source = \
                                Variable(inputs_valid_source).cuda(), Variable(inputs_valid_target).cuda(), \
                                Variable(labels_valid_source).cuda()
                        else:
                            inputs_valid_source, inputs_valid_target, labels_valid_source = Variable(inputs_valid_source), \
                                Variable(inputs_valid_target), Variable(labels_valid_source)

                        valid_inputs = torch.cat(
                            (inputs_valid_source, inputs_valid_target), dim=0)
                        valid_source_batch_size = inputs_valid_source.size(0)

                        if config['loss']['ly_type'] == 'cosine':
                            features, logits = base_network(valid_inputs)
                            source_logits = logits.narrow(
                                0, 0, valid_source_batch_size)
                        elif config['loss']['ly_type'] == 'euclidean':
                            features, _ = base_network(valid_inputs)
                            logits = -1.0 * loss.distance_to_centroids(
                                features, center_criterion.centers.detach())
                            source_logits = logits.narrow(
                                0, 0, valid_source_batch_size)

                        ad_net.train(False)
                        weight_ad = torch.ones(valid_inputs.size(0))
                        transfer_loss = transfer_criterion(features, ad_net, gradient_reverse_layer, \
                                   weight_ad, use_gpu)
                        ad_out, _ = ad_net(features.detach())
                        ad_acc, source_acc_ad, target_acc_ad = domain_cls_accuracy(
                            ad_out)

                        # source domain classification task loss
                        classifier_loss = class_criterion(
                            source_logits, labels_valid_source.long())

                        # fisher loss on labeled source domain
                        fisher_loss, fisher_intra_loss, fisher_inter_loss, center_grad = center_criterion(
                            features.narrow(0, 0,
                                            int(valid_inputs.size(0) / 2)),
                            labels_valid_source,
                            inter_class=loss_params["inter_type"],
                            intra_loss_weight=loss_params["intra_loss_coef"],
                            inter_loss_weight=loss_params["inter_loss_coef"])
                        # entropy minimization loss
                        em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))

                        # final loss
                        total_loss = loss_params["trade_off"] * transfer_loss \
                                     + fisher_loss \
                                     + loss_params["em_loss_coef"] * em_loss \
                                     + classifier_loss

                    # NOTE(review): only True at j == 0 (j < len(loader)), so
                    # validation is logged for a single batch per epoch.
                    if j % len(dset_loaders["source_valid"]) == 0:
                        config['out_file'].write(
                            'epoch {}: valid total loss={:0.4f}, valid transfer loss={:0.4f}, valid classifier loss={:0.4f}, '
                            'valid entropy min loss={:0.4f}, '
                            'valid fisher loss={:0.4f}, valid intra-group fisher loss={:0.4f}, valid inter-group fisher loss={:0.4f}, '
                            'valid source+target domain accuracy={:0.4f}, valid source domain accuracy={:0.4f}, valid target domain accuracy={:0.4f}\n'
                            .format(
                                i / len(dset_loaders["source"]),
                                total_loss.data.cpu().float().item(),
                                transfer_loss.data.cpu().float().item(),
                                classifier_loss.data.cpu().float().item(),
                                em_loss.data.cpu().float().item(),
                                fisher_loss.cpu().float().item(),
                                fisher_intra_loss.cpu().float().item(),
                                fisher_inter_loss.cpu().float().item(),
                                ad_acc,
                                source_acc_ad,
                                target_acc_ad,
                            ))

                        config['out_file'].flush()
                        writer.add_scalar("validation total loss",
                                          total_loss.data.cpu().float().item(),
                                          i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation classifier loss",
                            classifier_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        # NOTE(review): "validation entropy minimization loss"
                        # is written twice (here and again below) — the
                        # duplicate is presumably unintentional.
                        writer.add_scalar(
                            "validation entropy minimization loss",
                            em_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation transfer loss",
                            transfer_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation entropy minimization loss",
                            em_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation total fisher loss",
                            fisher_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation intra-group fisher",
                            fisher_intra_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation inter-group fisher",
                            fisher_inter_loss.data.cpu().float().item(),
                            i / len(dset_loaders["source"]))
                        writer.add_scalar(
                            "validation source+target domain accuracy", ad_acc,
                            i / len(dset_loaders["source"]))
                        writer.add_scalar("validation source domain accuracy",
                                          source_acc_ad,
                                          i / len(dset_loaders["source"]))
                        writer.add_scalar("validation target domain accuracy",
                                          target_acc_ad,
                                          i / len(dset_loaders["source"]))

                        if early_stop_engine.is_stop_training(
                                classifier_loss.cpu().float().item()):
                            config["out_file"].write(
                                "no improvement after {}, stop training at step {}\n"
                                .format(config["early_stop_patience"],
                                        i / len(dset_loaders["source"])))

                            sys.exit()

    # best_acc is maintained earlier in train() (outside this view) — TODO
    # confirm it is always bound before this return.
    return best_acc
# Instantiate the two networks, move them onto the target device, and apply
# the project's normal weight initialization to each.
models = {}
models['describer'] = Describer(categorical_dim)
models['translator'] = Translator(categorical_dim)
for model in models.values():
    model.to(device)
    model.apply(weights_init_normal)

# A single optimizer drives both networks jointly.
optimizers = {
    'all': create_optimizer([models['describer'], models['translator']], lr),
}

# Loss terms consumed by the trainer.
losses = {
    'entropy': loss.EntropyLoss(categorical_dim),
    'classification': loss.ClassificationLoss(),
}

# Run the full training schedule.
trainer = Trainer(dataloader, dataloader_subset, testloader, models,
                  optimizers, losses)
for _ in range(n_epochs):
    trainer.train()

# Persist the metrics recorded by the trainer.
for out_path, series in (
        ('data/losses.txt', trainer.losses),
        ('data/accuracies.txt', trainer.accuracies),
        ('data/misclassifications.txt', trainer.misclassifications),
):
    save_list(out_path, series)