def main():
    args = parse_args()
    torch.cuda.set_device(args.gpu)
    save_root = root
    stage = args.stage

    num_domain = 4
    num_classes = 65

    if args.save_root:
        save_root = args.save_root

    trg_ssl_train, trg_ssl_val = get_dataset(dataset=args.dataset,
                                             dataset_root=args.data_root,
                                             domain=args.trg_domain,
                                             ssl=True)
    src_ssl_train, src_ssl_val = get_dataset(dataset=args.dataset,
                                             dataset_root=args.data_root,
                                             domain=args.src_domain,
                                             ssl=True)

    model = get_model(args.model_name,
                      in_features=256,
                      num_classes=4,
                      num_domains=num_domain,
                      pretrained=True)
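
The `ssl=True` datasets and the 4-way head above correspond to a rotation-prediction pretext task. A minimal sketch of how such rotation labels are typically generated (illustrative RotNet-style batching only; not this repository's `get_dataset`):

import torch

def make_rotation_batch(images):
    """Rotate each NCHW image by 0/90/180/270 degrees and return 4-way labels.

    Illustrative pretext-task batching; the repository's ssl dataset may
    construct its rotation batches differently.
    """
    rotated = [torch.rot90(images, k, dims=(2, 3)) for k in range(4)]
    labels = torch.arange(4).repeat_interleave(images.size(0))
    return torch.cat(rotated, dim=0), labels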
Example #2
def main():
    args = parse_args()
    torch.cuda.set_device(args.gpu)
    save_root = root
    stage = args.stage

    num_domain = 4
    num_classes = 65

    if args.dataset == 'domainnet':
        num_domain = 6
        num_classes = 345
    elif args.dataset == 'officehome':
        num_domain = 4
        num_classes = 65

    if args.save_root:
        save_root = args.save_root

    trg_sup_train, trg_sup_val = get_dataset(dataset=args.dataset,
                                             dataset_root=args.data_root,
                                             domain=args.trg_domain,
                                             ssl=False)
    trg_num = domain_dict[args.dataset][args.trg_domain]
    src_train, src_val = get_dataset(dataset=args.dataset,
                                     dataset_root=args.data_root,
                                     domain=args.src_domain,
                                     ssl=False)
    src_num = domain_dict[args.dataset][args.src_domain]

    save_dir = None
    model = None

    #################################### STAGE 1 ####################################
    if stage == 1:

        if args.ssl:
            save_dir = join(save_root, 'stage1/rot/', args.trg_domain)
            if not os.path.isfile(join(save_dir, 'best_model.ckpt')):

                if not os.path.isdir(save_dir):
                    os.makedirs(save_dir, exist_ok=True)
                model = get_model(args.model_name,
                                  in_features=256,
                                  num_classes=4,
                                  num_domains=num_domain,
                                  pretrained=False)
                trg_ssl_train, trg_ssl_val = get_dataset(
                    dataset=args.dataset,
                    dataset_root=args.data_root,
                    domain=args.trg_domain,
                    # domain=[args.trg_domain, args.src_domain],
                    ssl=True)

                print('train stage 1')
                model = normal_train(args, model, trg_ssl_train, trg_ssl_val,
                                     args.iters[0], save_dir, args.trg_domain)
            else:
                print('found stage 1 model: ', save_dir)
        else:
            save_dir = join(save_root, 'stage1/sup/', args.trg_domain)
            if not os.path.isfile(join(save_dir, 'best_model.ckpt')):
                print('train stage 1')
                if not os.path.isdir(save_dir):
                    os.makedirs(save_dir, exist_ok=True)
                model = get_model(args.model_name,
                                  in_features=num_classes,
                                  num_classes=num_classes,
                                  num_domains=num_domain,
                                  pretrained=False)

                model = normal_train(args, model, trg_sup_train, trg_sup_val,
                                     args.iters[0], save_dir, args.trg_domain)
            else:
                print('found stage 1 model: ', save_dir)
        if args.only1:
            stage = 1
        else:
            stage += 1

    #################################### STAGE 2 ####################################
    if stage == 2:
        print('train stage 2')
        if args.ssl:
            model_pth = join(save_root, 'stage1/rot/', args.trg_domain,
                             'best_model.ckpt')
            print('load model from %s' % (model_pth))
            pre = torch.load(model_pth)
            save_dir = join(save_root, 'stage2/rot', args.save_dir)

        else:
            model_pth = join(save_root, 'stage1/sup/', args.trg_domain,
                             'best_model.ckpt')
            print('load model from %s' % (model_pth))
            pre = torch.load(model_pth)
            save_dir = join(save_root, 'stage2/sup', args.save_dir)

        # sys.stdout = open(join(save_dir, 'logs.txt'), 'w')
        model = get_model(args.model_name,
                          in_features=num_classes,
                          num_classes=num_classes,
                          num_domains=num_domain,
                          pretrained=False)
        model.load_state_dict(pre, strict=False)

        src_bn = 'bns.' + str(src_num)
        trg_bn = 'bns.' + str(trg_num)

        weight_dict = OrderedDict()
        for name, p in model.named_parameters():
            if trg_bn in name:
                weight_dict[name] = p
                new_name = name.replace(trg_bn, src_bn)
                weight_dict[new_name] = p
            elif src_bn in name:
                continue
            else:
                weight_dict[name] = p
        model.load_state_dict(weight_dict, strict=False)
        for name, p in model.named_parameters():
            if 'fc' in name:
                p.requires_grad = True
            else:
                p.requires_grad = False

        model.fc1.weight.requires_grad = True
        model.fc1.bias.requires_grad = True
        model.fc2.weight.requires_grad = True
        model.fc2.bias.requires_grad = True
        torch.nn.init.xavier_uniform_(model.fc1.weight)
        torch.nn.init.xavier_uniform_(model.fc2.weight)

        if not os.path.isdir(save_dir):
            os.makedirs(save_dir, exist_ok=True)

        model = normal_train(args,
                             model,
                             src_train,
                             src_val,
                             args.iters[1],
                             save_dir,
                             args.src_domain,
                             test_datset=trg_sup_val,
                             test_domain=args.trg_domain)
    #################################### STAGE 3 ####################################
    _, stage3_acc = test(args, model, trg_sup_val,
                         domain_dict[args.dataset][args.trg_domain])
    print('####################################')
    print('### stage 3 at stage1 iter: %0.3f' % (stage3_acc))
    print('####################################')
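
Stage 2 above remaps the target domain's batch-norm parameters into the source domain's slots before fine-tuning on source data. A standalone sketch of that remapping, assuming the per-domain norms are registered as `bns.<domain_index>` as in the snippet:

from collections import OrderedDict

def share_target_bn_with_source(model, trg_num, src_num):
    """Load the 'bns.<trg_num>' parameters into the 'bns.<src_num>' slots as well."""
    trg_bn, src_bn = 'bns.%d' % trg_num, 'bns.%d' % src_num
    weight_dict = OrderedDict()
    for name, p in model.named_parameters():
        if trg_bn in name:
            weight_dict[name] = p                          # keep the target copy
            weight_dict[name.replace(trg_bn, src_bn)] = p  # mirror it into the source slot
        elif src_bn in name:
            continue                                       # drop the old source copy
        else:
            weight_dict[name] = p
    model.load_state_dict(weight_dict, strict=False)
    return model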
Example #3
def main(unused_arg):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                           clone_on_cpu=FLAGS.clone_on_cpu,
                                           replica_id=FLAGS.task,
                                           num_replicas=FLAGS.num_replicas,
                                           num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')

    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    tf.gfile.MakeDirs(FLAGS.train_dir)

    with tf.Graph().as_default() as graph:
        with tf.device(config.inputs_device()):
            samples, num_samples = get_dataset.get_dataset(
                FLAGS.dataset,
                FLAGS.dataset_dir,
                split_name=FLAGS.train_split,
                is_training=True,
                image_size=[FLAGS.image_size, FLAGS.image_size],
                batch_size=clone_batch_size,
                channel=FLAGS.input_channel)
            tf.logging.info('Training on %s set: %d', FLAGS.train_split,
                            num_samples)
            inputs_queue = prefetch_queue.prefetch_queue(samples,
                                                         capacity=128 *
                                                         config.num_clones)
        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()
            # Define the model and create clones.
            model_fn = _build_model
            model_args = (inputs_queue, clone_batch_size)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)
        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        # Add summaries for model variables.
        if FLAGS.save_summaries_variables:
            for model_var in slim.get_model_variables():
                summaries.add(
                    tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))
        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy, FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor, FLAGS.number_of_steps,
                FLAGS.learning_power, FLAGS.slow_start_step,
                FLAGS.slow_start_learning_rate)
            optimizer = tf.train.AdamOptimizer(learning_rate)
            #optimizer = tf.train.RMSPropOptimizer(learning_rate, momentum=FLAGS.momentum)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('losses/total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            if (FLAGS.dataset == 'protein') and FLAGS.add_counts_logits:
                last_layers = ['Logits', 'Counts_logits']
            else:
                last_layers = ['Logits']
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        session_config.gpu_options.allow_growth = True
        session_config.gpu_options.per_process_gpu_memory_fraction = 0.9

        # Start the training.
        slim.learning.train(train_tensor,
                            FLAGS.train_dir,
                            is_chief=(FLAGS.task == 0),
                            master=FLAGS.master,
                            graph=graph,
                            log_every_n_steps=FLAGS.log_every_n_steps,
                            session_config=session_config,
                            startup_delay_steps=startup_delay_steps,
                            number_of_steps=FLAGS.number_of_steps,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs,
                            init_fn=train_utils.get_model_init_fn(
                                FLAGS.train_dir,
                                FLAGS.fine_tune_checkpoint,
                                FLAGS.initialize_last_layer,
                                last_layers,
                                ignore_missing_vars=True),
                            summary_op=summary_op,
                            saver=tf.train.Saver(max_to_keep=50))
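
`train_utils.get_model_gradient_multipliers` is project code; a rough sketch of the kind of last-layer gradient-multiplier mapping consumed by `slim.learning.multiply_gradients` (an assumption about its shape, not the repository's implementation):

import tensorflow as tf
slim = tf.contrib.slim

def last_layer_gradient_multipliers(last_layers, multiplier):
    """Map variables whose names match a last-layer scope to a larger multiplier."""
    gradient_multipliers = {}
    for var in slim.get_model_variables():
        if any(layer in var.op.name for layer in last_layers):
            gradient_multipliers[var.op.name] = multiplier
    return gradient_multipliers

# grads_and_vars = slim.learning.multiply_gradients(
#     grads_and_vars, last_layer_gradient_multipliers(['Logits'], 10.0))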
Example #4
if __name__ == "__main__":
    if len(sys.argv) == 4:
        dataset_name = sys.argv[1]
        tau = float(sys.argv[2])
        run_average = int(sys.argv[3])
    else:
        dataset_name = input("Enter dataset name: ")
        tau = float(input("Enter tau: "))
        run_average = int(input("Enter run average: "))
    # epsilon_schedule = [1]
    # epsilon_schedule = [0.001, 0.1, 0.3, 0.5, 0.7, 1]
    # epsilon_schedule = [0.001, 0.01, 0.1]
    epsilon_schedule = [0.01, 0.1]

    # X,y = get_dataset("sync_5,10000")
    X,y = get_dataset(dataset_name)
    N = X.shape[0]
    D = X.shape[1]

    # classifier = LogisticRegression(penalty='none', fit_intercept=False)
    classifier = LogisticRegression(C=100000000000000, fit_intercept=False)
    classifier.fit(X, y)
    predicted_labels = classifier.predict(X)
    print(predicted_labels)
    eq = np.equal(y, predicted_labels)
    eq = eq.astype(float)
    accuracy = np.mean(eq)
    print("Scikit-learn classifier got accuracy {0}".format(accuracy))

    print("N = ", N)
    print("D = ", D)
Example #5
def build_model():
  """Builds graph for model to train with rewrites for quantization.
  Returns:
    g: Graph with fake quantization ops and batch norm folding suitable for
    training quantized weights.
    train_tensor: Train op for execution during training.
  """
  g = tf.Graph()
  with g.as_default(), tf.device(
      tf.train.replica_device_setter(FLAGS.ps_tasks)):
    samples, _ = get_dataset.get_dataset(FLAGS.dataset, FLAGS.dataset_dir,
                                         split_name=FLAGS.train_split,
                                         is_training=True,
                                         image_size=[FLAGS.image_size, FLAGS.image_size],
                                         batch_size=FLAGS.batch_size,
                                         channel=FLAGS.input_channel)

    inputs = tf.identity(samples['image'], name='image')
    labels = tf.identity(samples['label'], name='label')
    model_options = common.ModelOptions(output_stride=FLAGS.output_stride)
    net, end_points = model.get_features(
        inputs,
        model_options=model_options,
        weight_decay=FLAGS.weight_decay,
        is_training=True,
        fine_tune_batch_norm=FLAGS.fine_tune_batch_norm)
    logits, _ = model.classification(net, end_points, 
                                     num_classes=FLAGS.num_classes,
                                     is_training=True)
    logits = slim.softmax(logits)
    focal_loss_tensor = train_utils.focal_loss(labels, logits, weights=1.0)
    # f1_loss_tensor = train_utils.f1_loss(labels, logits, weights=1.0)
    # cls_loss = f1_loss_tensor
    cls_loss = focal_loss_tensor

    # Gather update_ops
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    global_step = tf.train.get_or_create_global_step()
    learning_rate = train_utils.get_model_learning_rate(
          FLAGS.learning_policy, FLAGS.base_learning_rate,
          FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
          FLAGS.number_of_steps, FLAGS.learning_power,
          FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
    opt = tf.train.AdamOptimizer(learning_rate)
    # opt = tf.train.RMSPropOptimizer(learning_rate, momentum=FLAGS.momentum)
    summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    for loss in tf.get_collection(tf.GraphKeys.LOSSES):
      summaries.add(tf.summary.scalar('sub_losses/%s'%(loss.op.name), loss))
    classification_loss = tf.identity(cls_loss, name='classification_loss')
    summaries.add(tf.summary.scalar('losses/classification_loss', classification_loss))
    regularization_loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    regularization_loss = tf.add_n(regularization_loss, name='regularization_loss')
    summaries.add(tf.summary.scalar('losses/regularization_loss', regularization_loss))

    total_loss = tf.add(cls_loss, regularization_loss, name='total_loss')
    grads_and_vars = opt.compute_gradients(total_loss)

    total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')
    summaries.add(tf.summary.scalar('losses/total_loss', total_loss))

    grad_updates = opt.apply_gradients(grads_and_vars, global_step=global_step)
    update_ops.append(grad_updates)
    update_op = tf.group(*update_ops, name='update_barrier')
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')

  # Merge all summaries together.
  summary_op = tf.summary.merge(list(summaries))
  return g, train_tensor, summary_op
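
`train_utils.focal_loss` is project-specific. As a rough illustration only, a standard multi-label focal loss over predicted probabilities could look like the sketch below; the gamma/alpha values and reduction are assumptions, not the repository's implementation:

import tensorflow as tf

def focal_loss_sketch(labels, probs, gamma=2.0, alpha=0.25, eps=1e-7):
    """Illustrative focal loss; not the repository's train_utils.focal_loss."""
    labels = tf.cast(labels, tf.float32)
    probs = tf.clip_by_value(probs, eps, 1.0 - eps)
    pos_term = -alpha * tf.pow(1.0 - probs, gamma) * labels * tf.log(probs)
    neg_term = -(1.0 - alpha) * tf.pow(probs, gamma) * (1.0 - labels) * tf.log(1.0 - probs)
    return tf.reduce_mean(pos_term + neg_term)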
Example #6
def eval_model():
    """Evaluates model."""
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.gfile.MakeDirs(FLAGS.eval_dir)
    tf.logging.info('Evaluating on %s set', FLAGS.eval_split)
    g = tf.Graph()
    with g.as_default():
        samples, num_samples = get_dataset.get_dataset(
            FLAGS.dataset,
            FLAGS.dataset_dir,
            split_name=FLAGS.eval_split,
            is_training=False,
            image_size=[FLAGS.image_size, FLAGS.image_size],
            batch_size=FLAGS.batch_size,
            channel=FLAGS.input_channel)
        inputs = tf.identity(samples['image'], name='image')
        labels = tf.identity(samples['label'], name='label')
        model_options = common.ModelOptions(output_stride=FLAGS.output_stride)
        net, end_points = model.get_features(inputs,
                                             model_options=model_options,
                                             is_training=False,
                                             fine_tune_batch_norm=False)

        _, end_points = model.classification(net,
                                             end_points,
                                             num_classes=FLAGS.num_classes,
                                             is_training=False)
        eval_ops = metrics(end_points, labels)
        #num_samples = 1000
        num_batches = math.ceil(num_samples / float(FLAGS.batch_size))
        tf.logging.info('Eval num images %d', num_samples)
        tf.logging.info('Eval batch size %d and num batch %d',
                        FLAGS.batch_size, num_batches)
        # session_config = tf.ConfigProto(device_count={'GPU': 0})
        session_config = tf.ConfigProto(allow_soft_placement=True)
        session_config.gpu_options.allow_growth = True
        if FLAGS.use_slim:
            num_eval_iters = None
            if FLAGS.max_number_of_evaluations > 0:
                num_eval_iters = FLAGS.max_number_of_evaluations
            slim.evaluation.evaluation_loop(
                FLAGS.master,
                FLAGS.checkpoint_dir,
                logdir=FLAGS.eval_dir,
                num_evals=num_batches,
                eval_op=eval_ops,
                session_config=session_config,
                max_number_of_evaluations=num_eval_iters,
                eval_interval_secs=FLAGS.eval_interval_secs)
        else:
            with tf.Session(config=session_config) as sess:
                init_op = tf.group(tf.global_variables_initializer(),
                                   tf.local_variables_initializer())
                sess.run(init_op)
                saver_fn = get_checkpoint_init_fn(FLAGS.checkpoint_dir)
                saver_fn(sess)
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                try:
                    i = 0
                    all_pres = []
                    predictions_custom_list = []
                    all_labels = []
                    while not coord.should_stop():
                        logits_np, labels_np = sess.run(
                            [end_points['Logits_Predictions'], labels])
                        logits_np = logits_np[0]
                        labels_np = labels_np[0]
                        all_labels.append(labels_np)
                        labels_id = np.where(labels_np == 1)[0]
                        predictions_id = list(
                            np.where(logits_np > (_THRESHOULD))[0])
                        predictions_np = np.where(logits_np > (_THRESHOULD), 1,
                                                  0)
                        if np.sum(predictions_np) == 0:
                            max_id = np.argmax(logits_np)
                            predictions_np[max_id] = 1
                            predictions_id.append(max_id)
                        predictions_custom_list.append(predictions_np)
                        i += 1
                        sys.stdout.write(
                            'Image[{0}]--> labels:{1}, predictions: {2}\n'.
                            format(i, labels_id, predictions_id))
                        sys.stdout.flush()

                        predictions_image_list = []
                        for thre in range(1, FLAGS.threshould, 1):
                            predictions_id = list(
                                np.where(logits_np > (thre / 100000000))[0])
                            predictions_np = np.where(
                                logits_np > (thre / 100000000), 1, 0)
                            if np.sum(predictions_np) == 0:
                                max_id = np.argmax(logits_np)
                                predictions_np[max_id] = 1
                                predictions_id.append(max_id)
                            predictions_image_list.append(predictions_np)
                        all_pres.append(predictions_image_list)
                except tf.errors.OutOfRangeError:
                    coord.request_stop()
                    coord.join(threads)
                finally:
                    sys.stdout.write('\n')
                    sys.stdout.flush()
                    pred_rows = []
                    all_labels = np.stack(all_labels, 0)
                    pres_custom = np.stack(predictions_custom_list, 0)
                    eval_custom = metric_eval(all_labels, pres_custom)
                    sys.stdout.write(
                        'Eval[f1_score, precision, recall]: {}\n'.format(
                            eval_custom['All']))
                    sys.stdout.flush()
                    pred_rows.append(eval_custom)
                    all_pres = np.transpose(all_pres, (1, 0, 2))
                    for pre, thre in zip(all_pres,
                                         range(1, FLAGS.threshould, 1)):
                        pred_rows.append(metric_eval(all_labels, pre, thre))
                    columns = ['Thre'] + list(
                        PROTEIN_CLASS_NAMES.values()) + ['All']
                    submission_df = pd.DataFrame(pred_rows)[columns]
                    submission_df.to_csv(os.path.join('./result/protein',
                                                      'protein_eval.csv'),
                                         index=False)
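
The eval loop above turns per-class logits into multi-label predictions by thresholding and falling back to the top class when nothing crosses the threshold. A small NumPy sketch of that rule (the threshold value is a placeholder):

import numpy as np

def predict_multilabel(logits, threshold):
    """Binarize logits at `threshold`; if no class fires, keep the argmax."""
    preds = (logits > threshold).astype(np.int64)
    if preds.sum() == 0:
        preds[np.argmax(logits)] = 1
    return preds

# e.g. predict_multilabel(np.array([0.1, 0.05, 0.4]), threshold=0.5) -> [0, 0, 1]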
Example #7
def main():
    args = parse_args()
    torch.cuda.set_device(args.gpu)
    save_root = root
    if args.save_root:
        save_root = args.save_root
    stage = args.stage

    num_domain = 4
    num_classes = 65

    if args.dataset == 'domainnet':
        num_domain = 6
        num_classes = 345
    elif args.dataset == 'officehome':
        num_domain = 4
        num_classes = 65

    ### 1. train encoder with rotation task ###
    save_dir = join(save_root, args.save_dir, 'stage1')
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir, exist_ok=True)

    if stage == 1:
        train_dataset, val_dataset = get_dataset(dataset=args.dataset,
                                                 dataset_root=args.data_root,
                                                 domain=args.domain,
                                                 ssl=True)

        model = get_model(args.model_name,
                          in_features=256,
                          num_classes=4,
                          num_domains=num_domain,
                          pretrained=False)
        # model = get_rot_model(args.model_name, num_domains=6)
        # model = normal_train(args, model, train_dataset, val_dataset, args.iters[0], save_dir, args.domain,
        #                      save_model=True)

        stage += 1

    ### 2. train classifier with classification task ###
    if stage == 2:
        train_dataset, val_dataset = get_dataset(dataset=args.dataset,
                                                 dataset_root=args.data_root,
                                                 domain=args.domain,
                                                 ssl=False)

        # for i in range(4):
        #
        #     iter = i * 20000 + 10000
        #     # iter = i * 2 + 1
        #     model_pth = join(save_dir, '%d_weight.ckpt' % (iter))
        #     if(os.path.isfile(model_pth)):
        #         pre = torch.load(model_pth)
        #     else:
        #         print('no weight exists: ', model_pth)
        #         break
        #     print('load weight: ', join(save_dir, '%d_weight.ckpt' % (iter)))
        #     model = get_model(args.model_name, in_features=num_classes, num_classes=num_classes,
        #                        num_domains=num_domain, pretrained=True)
        #
        #     new_pre = OrderedDict()
        #     for key in pre.keys():
        #         if 'fc' in key:
        #             print(key)
        #         else:
        #             new_pre[key] = pre[key]
        #
        #     model.load_state_dict(new_pre, strict=False)
        #
        #     torch.nn.init.xavier_uniform_(model.fc1.weight)
        #     torch.nn.init.xavier_uniform_(model.fc2.weight)
        #     model.fc1.weight.requires_grad = True
        #     model.fc2.weight.requires_grad = True
        #
        #     save_dir_iter = join(save_root, args.save_dir, 'stage2_%d' % (iter))
        #     if not os.path.isdir(save_dir_iter):
        #         os.makedirs(save_dir_iter, exist_ok=True)
        #
        #     model = normal_train(args, model, train_dataset, val_dataset, args.iters[1], save_dir_iter, args.domain)

        model = get_model(args.model_name,
                          in_features=num_classes,
                          num_classes=num_classes,
                          num_domains=num_domain,
                          pretrained=False)

        # pre = torch.load(args.model_path)
        # new_pre = OrderedDict()
        # for key in pre.keys():
        #     if 'fc' in key:
        #         print(key)
        #     else:
        #         new_pre[key] = pre[key]
        #
        # model.load_state_dict(new_pre, strict=False)

        src_bn = 'bns.' + str(0)
        trg_bn = 'bns.' + str(1)

        weight_dict = OrderedDict()
        for name, p in model.named_parameters():
            if trg_bn in name:
                weight_dict[name] = p
                new_name = name.replace(trg_bn, src_bn)
                weight_dict[new_name] = p
            elif src_bn in name:
                continue
            else:
                weight_dict[name] = p
        model.load_state_dict(weight_dict, strict=False)
        for name, p in model.named_parameters():
            p.requires_grad = False

        model.fc1.weight.requires_grad = True
        model.fc1.bias.requires_grad = True
        model.fc2.weight.requires_grad = True
        model.fc2.bias.requires_grad = True
        torch.nn.init.xavier_uniform_(model.fc1.weight)
        torch.nn.init.xavier_uniform_(model.fc2.weight)

        # if args.onlyfc:
        #     print('train only fc layer')
        #     for name, p in model.named_parameters():
        #         p.requires_grad = False
        #
        # torch.nn.init.xavier_uniform_(model.fc1.weight)
        # torch.nn.init.xavier_uniform_(model.fc2.weight)
        # model.fc1.weight.requires_grad = True
        # model.fc2.weight.requires_grad = True

        save_dir = join(save_root, args.save_dir, 'stage2')
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir, exist_ok=True)

        model = normal_train(args, model, train_dataset, val_dataset,
                             args.iters[1], save_dir, args.domain)
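
Stage 2 in these snippets freezes the backbone and re-initializes only the classifier head before training on the labeled data. A minimal sketch of that setup, assuming the model exposes `fc1`/`fc2` linear layers as in the examples above:

import torch

def prepare_classifier_head(model):
    """Freeze all parameters, then unfreeze and re-initialize the two FC layers."""
    for p in model.parameters():
        p.requires_grad = False
    for fc in (model.fc1, model.fc2):
        torch.nn.init.xavier_uniform_(fc.weight)
        fc.weight.requires_grad = True
        fc.bias.requires_grad = True
    return model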