def main():
    """Create the model and start the training.

    Adversarial training with two prediction heads. In every sub-iteration:

    1. The segmentation network ``model`` is trained with ``loss_calc`` on a
       labelled source batch.
    2. Through the two discriminators, ``model`` is pushed to make its
       target-batch predictions look like source ones (label
       ``source_label``).
    3. The discriminators are trained to tell source predictions
       (``source_label``) from target ones (``target_label``). ``model_D1``
       sees the first head and ``model_D2`` the second.

    All settings come from the module-level ``args`` namespace. Snapshots and
    TensorBoard scalars are written under ``args.snapshot_dir``.

    Raises:
        ValueError: if ``args.model`` / ``args.training_option`` do not select
            a supported network (otherwise ``model`` would be unbound).
    """
    model_num = 0  # next slot of the rotating periodic snapshots

    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    random.seed(args.random_seed)

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    h, w = map(int, args.input_size_target.split(','))
    input_size_target = (h, w)

    cudnn.enabled = True

    # Create network
    if args.model != 'DeepLab':
        raise ValueError('Unsupported model: {}'.format(args.model))
    if args.training_option == 1:
        model = Res_Deeplab(num_classes=args.num_classes,
                            num_layers=args.num_layers)
    elif args.training_option == 2:
        model = Res_Deeplab2(num_classes=args.num_classes)
    else:
        raise ValueError('Unsupported training_option: {}'.format(
            args.training_option))

    if args.restore_from.startswith('http'):
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    new_params = model.state_dict().copy()

    # Dump both key sets so a wrong args.i_parts_index is easy to diagnose.
    for key in saved_state_dict:
        print(key)
    for key in new_params:
        print(key)

    for key in saved_state_dict:
        key_parts = key.split('.')
        # Checkpoint key with its first args.i_parts_index components dropped.
        new_key = '.'.join(key_parts[args.i_parts_index:])
        if new_key not in new_params:
            continue
        print("Restored...")
        # With not_restore_last, layer5/layer6 weights are not copied: those
        # entries keep the model's own values.
        if (args.not_restore_last == True
                and key_parts[args.i_parts_index] in ('layer5', 'layer6')):
            continue
        new_params[new_key] = saved_state_dict[key]

    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # Make sure the output directory exists before anything writes into it.
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    writer = SummaryWriter(log_dir=args.snapshot_dir)

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    trainloader = data.DataLoader(sourceDataSet(
        args.data_dir,
        args.data_list,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size,
        random_rotate=args.augment_1,
        random_flip=args.augment_1,
        random_lighting=args.augment_1,
        mean=IMG_MEAN_SOURCE,
        ignore_label=args.ignore_label),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(isprsDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN_TARGET,
        ignore_label=args.ignore_label),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    valloader = data.DataLoader(valDataSet(args.data_dir_val,
                                           args.data_list_val,
                                           crop_size=input_size_target,
                                           mean=IMG_MEAN_TARGET,
                                           scale=False,
                                           mirror=False),
                                batch_size=1,
                                shuffle=False,
                                pin_memory=True)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    interp = nn.Upsample(size=(input_size[0], input_size[1]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[0],
                                      input_size_target[1]),
                                mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    def _bce_against(d_out, label):
        """BCE-with-logits of discriminator output ``d_out`` vs. a constant map."""
        target = Variable(torch.FloatTensor(
            d_out.data.size()).fill_(label)).cuda(args.gpu)
        return bce_loss(d_out, target)

    def _set_requires_grad(net, flag):
        """Enable/disable gradient accumulation for every parameter of ``net``."""
        for param in net.parameters():
            param.requires_grad = flag

    def _save_snapshot(tag):
        """Save G, D1 and D2 as model_<tag>.pth, model_<tag>_D1.pth, model_<tag>_D2.pth."""
        for net, suffix in ((model, ''), (model_D1, '_D1'), (model_D2, '_D2')):
            torch.save(net.state_dict(),
                       osp.join(args.snapshot_dir,
                                'model_' + tag + suffix + '.pth'))

    # Which layers to freeze
    non_trainable(args.dont_train, model)

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # Gradients are accumulated over iter_size sub-iterations, hence the
        # division of every loss by args.iter_size.
        for _ in range(args.iter_size):

            # ---- train G ----
            # don't accumulate grads in D
            _set_requires_grad(model_D1, False)
            _set_requires_grad(model_D2, False)

            # Train with source. A batch that raises one of these errors is
            # skipped (best effort) and the next one is tried; the error is
            # reported instead of being silently swallowed.
            while True:
                try:
                    _, batch = next(trainloader_iter)
                    images, labels, _, train_name = batch
                    images = Variable(images).cuda(args.gpu)

                    pred1, pred2 = model(images)
                    pred1 = interp(pred1)
                    pred2 = interp(pred2)

                    loss_seg1 = loss_calc(pred1, labels, args.gpu,
                                          args.ignore_label, train_name)
                    loss_seg2 = loss_calc(pred2, labels, args.gpu,
                                          args.ignore_label, train_name)

                    loss = loss_seg2 + args.lambda_seg * loss_seg1

                    # proper normalization
                    loss = loss / args.iter_size
                    loss.backward()
                except (RuntimeError, AssertionError, AttributeError) as err:
                    print('Skipping source batch: {}'.format(err))
                    continue
                loss_seg_value1 += loss_seg1.item() / args.iter_size
                loss_seg_value2 += loss_seg2.item() / args.iter_size
                break

            if args.experiment == 1:
                # Which layers to freeze
                non_trainable('0', model)

            # Train with target: G is rewarded when D labels its target
            # predictions as source_label.
            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # Predictions are (N, C, H, W) after the bilinear upsampling, so
            # the class axis is dim=1.
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_adv_target1 = _bce_against(D_out1, source_label)
            loss_adv_target2 = _bce_against(D_out2, source_label)

            loss = (args.lambda_adv_target1 * loss_adv_target1 +
                    args.lambda_adv_target2 * loss_adv_target2)
            loss = loss / args.iter_size
            loss.backward()

            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            if args.experiment == 1:
                # Which layers to freeze
                non_trainable(args.dont_train, model)

            # ---- train D ----
            # bring back requires_grad
            _set_requires_grad(model_D1, True)
            _set_requires_grad(model_D2, True)

            # Source predictions are labelled source_label and target ones
            # target_label. detach() keeps these losses out of G's gradients.
            # The extra /2 averages the source and target halves.
            for d_pred1, d_pred2, label in (
                    (pred1.detach(), pred2.detach(), source_label),
                    (pred_target1.detach(), pred_target2.detach(),
                     target_label)):
                D_out1 = model_D1(F.softmax(d_pred1, dim=1))
                D_out2 = model_D2(F.softmax(d_pred2, dim=1))

                loss_D1 = _bce_against(D_out1, label) / args.iter_size / 2
                loss_D2 = _bce_against(D_out2, label) / args.iter_size / 2

                loss_D1.backward()
                loss_D2.backward()

                loss_D_value1 += loss_D1.item()
                loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        # Early stop: final snapshot is tagged with args.num_steps.
        if i_iter >= args.num_steps_stop - 1:
            _save_snapshot(str(args.num_steps))
            break

        # Periodic snapshots rotate through args.num_models_keep slots.
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            if model_num != args.num_models_keep:
                _save_snapshot(str(model_num))
                model_num = model_num + 1
            if model_num == args.num_models_keep:
                model_num = 0

        # Validation
        if (i_iter % args.val_every == 0 and i_iter != 0) or i_iter == 1:
            validation(valloader, model, interp_target, writer, i_iter,
                       [37, 41, 10])

        # Save for tensorboardx
        writer.add_scalar('loss_seg_value1', loss_seg_value1, i_iter)
        writer.add_scalar('loss_seg_value2', loss_seg_value2, i_iter)
        writer.add_scalar('loss_adv_target_value1', loss_adv_target_value1,
                          i_iter)
        writer.add_scalar('loss_adv_target_value2', loss_adv_target_value2,
                          i_iter)
        writer.add_scalar('loss_D_value1', loss_D_value1, i_iter)
        writer.add_scalar('loss_D_value2', loss_D_value2, i_iter)

    writer.close()
Exemplo n.º 2
0
def main():
    """Create the model and start the training.

    Supervised (non-adversarial) training of a single-output ``Res_Deeplab``
    on ``args.data_dir2`` / ``args.data_list2`` with ``loss_calc``. All
    settings come from the module-level ``args`` namespace. Outputs, all under
    ``args.snapshot_dir``:

    * ``model_<k>.pth``: a rotating snapshot every ``args.save_pred_every``
      iterations, with ``k`` cycling through ``args.num_models_keep`` slots;
    * ``bestmodel_<r>.pth``: the five best checkpoints by the mIoU returned
      from ``validation``, with ``r = 0`` the best;
    * TensorBoard scalars.

    Raises:
        ValueError: if ``args.model`` is not ``'DeepLab'`` (otherwise
            ``model`` would be unbound further down).
    """
    if args.model != 'DeepLab':
        raise ValueError('Unsupported model: {}'.format(args.model))

    model_num = 0  # next slot of the rotating periodic snapshots

    torch.manual_seed(args.random_seed)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    writer = SummaryWriter(log_dir=args.snapshot_dir)

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True
    cudnn.benchmark = True

    # init G
    model = Res_Deeplab(num_classes=args.num_classes)

    saved_state_dict = torch.load(args.restore_from)

    new_params = model.state_dict().copy()

    for key in saved_state_dict:
        key_parts = key.split('.')
        # With not_restore_last, layer5 weights are not copied and keep the
        # model's own values.
        if args.not_restore_last == True and key_parts[1] == 'layer5':
            continue
        # The target key is the checkpoint key without its first component.
        new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    trainloader2 = data.DataLoader(sourceDataSet(
        args.data_dir2,
        args.data_list2,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size,
        random_rotate=False,
        random_flip=args.augment_2,
        random_lighting=args.augment_2,
        random_blur=args.augment_2,
        random_scaling=args.augment_2,
        mean=IMG_MEAN_SOURCE2,
        ignore_label=args.ignore_label,
        source=args.source),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    trainloader_iter2 = enumerate(trainloader2)

    valloader = data.DataLoader(valDataSet(args.data_dir_val,
                                           args.data_list_val,
                                           crop_size=input_size,
                                           mean=IMG_MEAN_SOURCE2,
                                           mirror=False,
                                           source=args.source),
                                batch_size=1,
                                shuffle=False,
                                pin_memory=True)

    # The 10x-lr parameter group gets ten times the base learning rate.
    optimizer = optim.SGD([{
        'params': get_1x_lr_params_NOscale(model),
        'lr': args.learning_rate
    }, {
        'params': get_10x_lr_params(model),
        'lr': 10 * args.learning_rate
    }],
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    interp = nn.Upsample(size=(input_size[0], input_size[1]), mode='bilinear')

    # Best 5 mIoUs, sorted in descending order. best_mIoUs[r] belongs to
    # bestmodel_<r>.pth, and 0.0 marks a still-unused slot.
    best_mIoUs = [0.0, 0.0, 0.0, 0.0, 0.0]

    def _best_path(rank):
        """Path of the checkpoint holding the rank-th best model."""
        return osp.join(args.snapshot_dir, 'bestmodel_' + str(rank) + '.pth')

    def _update_best_models(mIoU, i_iter):
        """Insert ``mIoU`` into ``best_mIoUs`` and keep the files in step."""
        for rank, best in enumerate(best_mIoUs):
            if best < mIoU:
                # Shift the lower-ranked checkpoints down one slot, so that
                # bestmodel_<k>.pth keeps matching best_mIoUs[k]. The
                # previously last-ranked file is overwritten. Saving straight
                # over slot `rank` would instead destroy the model that held
                # that rank.
                for k in range(len(best_mIoUs) - 1, rank, -1):
                    if osp.exists(_best_path(k - 1)):
                        os.replace(_best_path(k - 1), _best_path(k))
                torch.save(model.state_dict(), _best_path(rank))
                best_mIoUs.insert(rank, mIoU)
                best_mIoUs.pop()
                print("Saved model at iteration %d as the best %d" %
                      (i_iter, rank))
                return

    for i_iter in range(args.num_steps):
        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        # A batch that raises one of these errors is skipped (best effort)
        # and the next one is tried; the error is reported, not swallowed.
        while True:
            try:
                _, batch = next(trainloader_iter2)
                images2, labels2, _, train_name2 = batch
                images2 = Variable(images2).cuda(args.gpu)

                pred2 = model(images2)
                pred2 = interp(pred2)

                loss = loss_calc(pred2, labels2, args.gpu, args.ignore_label,
                                 train_name2)
                loss.backward()
                break
            except (RuntimeError, AssertionError, AttributeError) as err:
                print('Skipping training batch: {}'.format(err))
        print('Iter ...')
        optimizer.step()

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            if model_num != args.num_models_keep:
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir,
                             'model_' + str(model_num) + '.pth'))
                model_num = model_num + 1
            if model_num == args.num_models_keep:
                model_num = 0

        # Validation
        if (i_iter % args.val_every == 0 and i_iter != 0) or i_iter == 1:
            mIoU = validation(valloader, model, interp, writer, i_iter,
                              [37, 41, 10])
            _update_best_models(mIoU, i_iter)

        # Save for tensorboardx; log a plain float, not the graph-attached tensor.
        writer.add_scalar('loss', loss.item(), i_iter)

    writer.close()
Exemplo n.º 3
0
def main():
    """Create the model and start the training."""
    model_num = 0  # The number of model (for saving models)

    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    random.seed(args.random_seed)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    writer = SummaryWriter(log_dir=args.snapshot_dir)

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    h, w = map(int, args.input_size_target.split(','))
    input_size_target = (h, w)

    cudnn.enabled = True
    gpu = args.gpu
    cudnn.benchmark = True

    # init G
    if args.model == 'DeepLab':
        if args.training_option == 1:
            model = Res_Deeplab(num_classes=args.num_classes,
                                num_layers=args.num_layers,
                                dropout=args.dropout,
                                after_layer=args.after_layer)
        elif args.training_option == 2:
            model = Res_Deeplab2(num_classes=args.num_classes)
        '''elif args.training_option == 3:
            model = Res_Deeplab50(num_classes=args.num_classes)'''

        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()

        for k, v in saved_state_dict.items():
            print(k)

        for k in new_params:
            print(k)

        for i in saved_state_dict:
            i_parts = i.split('.')
            if '.'.join(i_parts[args.i_parts_index:]) in new_params:
                print("Restored...")
                if args.not_restore_last == True:
                    if not i_parts[
                            args.i_parts_index] == 'layer5' and not i_parts[
                                args.i_parts_index] == 'layer6':
                        new_params['.'.join(i_parts[args.i_parts_index:]
                                            )] = saved_state_dict[i]
                else:
                    new_params['.'.join(
                        i_parts[args.i_parts_index:])] = saved_state_dict[i]

        model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes,
                               extra_layers=args.extra_discriminator_layers)
    model_D2 = FCDiscriminator(num_classes=args.num_classes, extra_layers=0)

    model_D1.train()
    model_D1.cuda(args.gpu)
    model_D2.train()
    model_D2.cuda(args.gpu)

    trainloader = data.DataLoader(sourceDataSet(
        args.data_dir,
        args.data_list,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size,
        random_rotate=False,
        random_flip=args.augment_1,
        random_lighting=args.augment_1,
        random_blur=args.augment_1,
        random_scaling=args.augment_1,
        mean=IMG_MEAN_SOURCE,
        ignore_label=args.ignore_label),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    trainloader2 = data.DataLoader(sourceDataSet(
        args.data_dir2,
        args.data_list2,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size,
        random_rotate=False,
        random_flip=args.augment_2,
        random_lighting=args.augment_2,
        random_blur=args.augment_2,
        random_scaling=args.augment_2,
        mean=IMG_MEAN_SOURCE2,
        ignore_label=args.ignore_label),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    trainloader_iter2 = enumerate(trainloader2)

    if args.num_of_targets > 1:

        IMG_MEAN_TARGET1 = np.array(
            (101.41694189393208, 89.68194541655483, 77.79408426901315),
            dtype=np.float32)  # crowdai all BGR

        targetloader1 = data.DataLoader(isprsDataSet(
            args.data_dir_target1,
            args.data_list_target1,
            max_iters=args.num_steps * args.iter_size * args.batch_size,
            crop_size=input_size_target,
            scale=False,
            mean=IMG_MEAN_TARGET1,
            ignore_label=args.ignore_label),
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=args.num_workers,
                                        pin_memory=True)

        targetloader_iter1 = enumerate(targetloader1)

    targetloader = data.DataLoader(isprsDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mean=IMG_MEAN_TARGET,
        ignore_label=args.ignore_label),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    valloader = data.DataLoader(valDataSet(args.data_dir_val,
                                           args.data_list_val,
                                           crop_size=input_size_target,
                                           mean=IMG_MEAN_TARGET,
                                           scale=args.val_scale,
                                           mirror=False),
                                batch_size=1,
                                shuffle=False,
                                pin_memory=True)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()
    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()
    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    if args.weighted_loss == True:
        bce_loss = torch.nn.BCEWithLogitsLoss()
    else:
        bce_loss = torch.nn.BCEWithLogitsLoss()

    interp = nn.Upsample(size=(input_size[0], input_size[1]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[0],
                                      input_size_target[1]),
                                mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # Which layers to freeze
    non_trainable(args.dont_train, model)

    # List saving all best 5 mIoU's
    best_mIoUs = [0.0, 0.0, 0.0, 0.0, 0.0]

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0
        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            ################################## train G #################################

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False
            for param in model_D2.parameters():
                param.requires_grad = False

            ################################## train with source #################################

            while True:
                try:
                    _, batch = next(
                        trainloader_iter)  # Cityscapes, only discriminator1
                    images, labels, _, train_name = batch
                    images = Variable(images).cuda(args.gpu)

                    _, batch = next(
                        trainloader_iter2
                    )  # Main (airsim) discriminator2 and final output
                    images2, labels2, size, train_name2 = batch
                    images2 = Variable(images2).cuda(args.gpu)

                    pred1, _ = model(images)
                    pred1 = interp(pred1)

                    _, pred2 = model(images2)
                    pred2 = interp(pred2)

                    loss_seg1 = loss_calc(pred1, labels, args.gpu,
                                          args.ignore_label, train_name,
                                          weights1)
                    loss_seg2 = loss_calc(pred2, labels2, args.gpu,
                                          args.ignore_label, train_name2,
                                          weights2)

                    loss = loss_seg2 + args.lambda_seg * loss_seg1

                    # proper normalization
                    loss = loss / args.iter_size
                    loss.backward()

                    if isinstance(loss_seg1.data.cpu().numpy(), list):
                        loss_seg_value1 += loss_seg1.data.cpu().numpy(
                        )[0] / args.iter_size
                    else:
                        loss_seg_value1 += loss_seg1.data.cpu().numpy(
                        ) / args.iter_size

                    if isinstance(loss_seg2.data.cpu().numpy(), list):
                        loss_seg_value2 += loss_seg2.data.cpu().numpy(
                        )[0] / args.iter_size
                    else:
                        loss_seg_value2 += loss_seg2.data.cpu().numpy(
                        ) / args.iter_size

                    break
                except (RuntimeError, AssertionError, AttributeError):
                    continue

            ###################################################################################################

            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)

            if args.num_of_targets > 1:
                _, batch1 = next(targetloader_iter1)
                images1, _, _ = batch1
                images1 = Variable(images1).cuda(args.gpu)

                pred_target1, _ = model(images1)

            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            ################################## train with target #################################
            if args.adv_option == 1 or args.adv_option == 3:

                D_out1 = model_D1(F.softmax(pred_target1))
                D_out2 = model_D2(F.softmax(pred_target2))

                loss_adv_target1 = bce_loss(
                    D_out1,
                    Variable(
                        torch.FloatTensor(
                            D_out1.data.size()).fill_(source_label)).cuda(
                                args.gpu))

                loss_adv_target2 = bce_loss(
                    D_out2,
                    Variable(
                        torch.FloatTensor(
                            D_out2.data.size()).fill_(source_label)).cuda(
                                args.gpu))

                loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
                loss = loss / args.iter_size
                loss.backward()

                if isinstance(loss_adv_target1.data.cpu().numpy(), list):
                    loss_adv_target_value1 += loss_adv_target1.data.cpu(
                    ).numpy()[0] / args.iter_size
                else:
                    loss_adv_target_value1 += loss_adv_target1.data.cpu(
                    ).numpy() / args.iter_size

                if isinstance(loss_adv_target2.data.cpu().numpy(), list):
                    loss_adv_target_value2 += loss_adv_target2.data.cpu(
                    ).numpy()[0] / args.iter_size
                else:
                    loss_adv_target_value2 += loss_adv_target2.data.cpu(
                    ).numpy() / args.iter_size

            ###################################################################################################
            if args.adv_option == 2 or args.adv_option == 3:
                pred1, _ = model(images)
                pred1 = interp(pred1)
                _, pred2 = model(images2)
                pred2 = interp(pred2)
                '''pred1 = pred1.detach()
                pred2 = pred2.detach()'''

                D_out1 = model_D1(F.softmax(pred1, dim=1))
                D_out2 = model_D2(F.softmax(pred2, dim=1))

                loss_adv_target1 = bce_loss(
                    D_out1,
                    Variable(
                        torch.FloatTensor(
                            D_out1.data.size()).fill_(target_label)).cuda(
                                args.gpu))

                loss_adv_target2 = bce_loss(
                    D_out2,
                    Variable(
                        torch.FloatTensor(
                            D_out2.data.size()).fill_(target_label)).cuda(
                                args.gpu))

                loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
                loss = loss / args.iter_size
                loss.backward()

                if isinstance(loss_adv_target1.data.cpu().numpy(), list):
                    loss_adv_target_value1 += loss_adv_target1.data.cpu(
                    ).numpy()[0] / args.iter_size
                else:
                    loss_adv_target_value1 += loss_adv_target1.data.cpu(
                    ).numpy() / args.iter_size

                if isinstance(loss_adv_target2.data.cpu().numpy(), list):
                    loss_adv_target_value2 += loss_adv_target2.data.cpu(
                    ).numpy()[0] / args.iter_size
                else:
                    loss_adv_target_value2 += loss_adv_target2.data.cpu(
                    ).numpy() / args.iter_size

            ###################################################################################################
            ################################## train D #################################

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            ################################## train with source #################################
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1))
            D_out2 = model_D2(F.softmax(pred2))

            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            if isinstance(loss_D1.data.cpu().numpy(), list):
                loss_D_value1 += loss_D1.data.cpu().numpy()[0]
            else:
                loss_D_value1 += loss_D1.data.cpu().numpy()

            if isinstance(loss_D2.data.cpu().numpy(), list):
                loss_D_value2 += loss_D2.data.cpu().numpy()[0]
            else:
                loss_D_value2 += loss_D2.data.cpu().numpy()

            ################################# train with target #################################

            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1))
            D_out2 = model_D2(F.softmax(pred_target2))

            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(target_label)).cuda(
                            args.gpu))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(target_label)).cuda(
                            args.gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            if isinstance(loss_D1.data.cpu().numpy(), list):
                loss_D_value1 += loss_D1.data.cpu().numpy()[0]
            else:
                loss_D_value1 += loss_D1.data.cpu().numpy()

            if isinstance(loss_D2.data.cpu().numpy(), list):
                loss_D_value2 += loss_D2.data.cpu().numpy()[0]
            else:
                loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            if model_num != args.num_models_keep:
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir,
                             'model_' + str(model_num) + '.pth'))
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'model_' + str(model_num) + '_D1.pth'))
                torch.save(
                    model_D2.state_dict(),
                    osp.join(args.snapshot_dir,
                             'model_' + str(model_num) + '_D2.pth'))
                model_num = model_num + 1
            if model_num == args.num_models_keep:
                model_num = 0

        # Validation
        if (i_iter % args.val_every == 0 and i_iter != 0) or i_iter == 1:
            mIoU = validation(valloader, model, interp_target, writer, i_iter,
                              [37, 41, 10])
            for i in range(0, len(best_mIoUs)):
                if best_mIoUs[i] < mIoU:
                    torch.save(
                        model.state_dict(),
                        osp.join(args.snapshot_dir,
                                 'bestmodel_' + str(i) + '.pth'))
                    torch.save(
                        model_D1.state_dict(),
                        osp.join(args.snapshot_dir,
                                 'bestmodel_' + str(i) + '_D1.pth'))
                    torch.save(
                        model_D2.state_dict(),
                        osp.join(args.snapshot_dir,
                                 'bestmodel_' + str(i) + '_D2.pth'))
                    best_mIoUs.append(mIoU)
                    print("Saved model at iteration %d as the best %d" %
                          (i_iter, i))
                    best_mIoUs.sort(reverse=True)
                    best_mIoUs = best_mIoUs[:5]
                    break

        # Save for tensorboardx
        writer.add_scalar('loss_seg_value1', loss_seg_value1, i_iter)
        writer.add_scalar('loss_seg_value2', loss_seg_value2, i_iter)
        writer.add_scalar('loss_adv_target_value1', loss_adv_target_value1,
                          i_iter)
        writer.add_scalar('loss_adv_target_value2', loss_adv_target_value2,
                          i_iter)
        writer.add_scalar('loss_D_value1', loss_D_value1, i_iter)
        writer.add_scalar('loss_D_value2', loss_D_value2, i_iter)

    writer.close()
Exemplo n.º 4
0
def main():
    """Create the model and start the training.

    Steps performed, in order:

    * Route ``sys.stdout`` / ``sys.stderr`` through ``Logger`` and read the
      command-line options with ``get_arguments()``.
    * Build the source, target and validation data-list files with
      ``ml.makedatalist`` and create ``args.snapshot_dir`` if it is missing.
    * Build the source / target training loaders and the validation loader
      (batch size 1, no shuffling).
    * Build the segmentation network (``source2targetNet``), an output-space
      discriminator (``labelDiscriminator``) and a feature-space
      discriminator (``featureDiscriminator``), and move them to ``args.gpu``.
    * Optionally copy pretrained encoder weights (``args.pretrain == 1``)
      and/or reload all three networks from snapshots
      (``args.iter_start != 0``).
    * Train for ``args.num_steps`` iterations with one Adam optimizer and one
      StepLR scheduler per network. Losses are printed every 50 iterations;
      validation and checkpointing happen every ``args.save_pred_every``
      iterations.
    """
    # start logger
    # Both streams are wrapped by Logger (presumably tee-ing console output to
    # a log file as well -- see Logger).
    sys.stdout = Logger(stream=sys.stdout)
    sys.stderr = Logger(stream=sys.stderr)

    # GPU usage is hard-coded on; the `else` branches below are only reached
    # if this flag is edited by hand.
    usecuda = True
    cudnn.enabled = True
    args = get_arguments()

    # makedatalist
    # Write the list files read by the datasets below (presumably one entry per
    # file found in each data directory -- confirm in ml.makedatalist).
    ml.makedatalist(args.data_dir_img, args.data_list)
    ml.makedatalist(args.data_dir_target, args.data_list_target)
    ml.makedatalist(args.data_dir_val, args.data_list_val)

    # setting logging directory
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    """
        load the data
    """
    # Sizes are given on the command line as "W,H" and kept as (w, h) tuples.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    # NOTE: this overwrites `w`/`h`; from here on `w` is the *target* width,
    # which is what featureDiscriminator receives below.
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    # max_iters=args.num_steps presumably makes the dataset long enough to
    # supply one batch per training iteration (the loop calls __next__ once
    # per iteration and never re-creates the iterator) -- confirm in
    # sourceDataSet / targetDataSet.
    trainloader = data.DataLoader(sourceDataSet(args.data_dir_img,
                                                args.data_dir_label,
                                                args.data_list,
                                                max_iters=args.num_steps,
                                                crop_size=input_size),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(targetDataSet(args.data_dir_target,
                                                 args.data_dir_target_label,
                                                 args.data_list_target,
                                                 max_iters=args.num_steps,
                                                 crop_size=input_size_target),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers)

    targetloader_iter = enumerate(targetloader)

    valloader = data.DataLoader(targetDataSet_val(args.data_dir_val,
                                                  args.data_dir_val_label,
                                                  args.data_list_val,
                                                  crop_size=input_size_target),
                                batch_size=1,
                                shuffle=False)
    """
        build the network
    """
    # Single-channel input, two output classes.
    model = source2targetNet(in_channels=1, out_channels=2)

    # Discriminator on the softmaxed segmentation output (output space).
    model_label = labelDiscriminator(num_classes=args.num_classes)

    # Discriminator on the intermediate feature map returned by `model`.
    input_channels = 64
    level = 1
    # NOTE(review): with level = 1 this is `w / 1`, i.e. the target width as a
    # float (true division). Confirm featureDiscriminator accepts a float here,
    # or that `//` was intended.
    model_feature = featureDiscriminator(input_channels=input_channels,
                                         input_size=w / (2**(level - 1)),
                                         num_classes=args.num_classes,
                                         fc_classifier=3)

    model.train()
    model_label.train()
    model_feature.train()

    if usecuda:
        cudnn.benchmark = True
        model.cuda(args.gpu)
        model_label.cuda(args.gpu)
        model_feature.cuda(args.gpu)
    """
        Loading the pretrain model
    """
    if args.pretrain == 1:
        # Loads a whole pickled model object. `.module` suggests it was saved
        # wrapped in nn.DataParallel -- confirm. torch.load unpickles, so only
        # point this at trusted files.
        old_model = torch.load(args.restore_from,
                               map_location='cuda:' + str(args.gpu))

        model_encoder_dict = model.encoder.state_dict()
        model_domain_decoder_dict = model.domain_decoder.state_dict()

        pretrained_dict = old_model.module.state_dict()

        # Lists of regex patterns matched against parameter names.
        frozen_layer = args.frozen_layer
        random_layer = args.random_layer
        for k, v in pretrained_dict.items():
            # Non-empty list <=> at least one pattern matched key `k`.
            flag_random = [
                True for pattern in random_layer
                if re.search(pattern, k) is not None
            ]
            flag_frozen = [
                True for pattern in frozen_layer
                if re.search(pattern, k) is not None
            ]
            if len(flag_frozen) != 0:
                # NOTE(review): this only flags the tensor in the *pretrained*
                # state dict. state_dict() returns detached tensors, and
                # load_state_dict copies values into the model's existing
                # parameters. So `model`'s parameters keep requires_grad=True,
                # and the optimizer below covers all of model.parameters().
                # The "frozen" layers therefore still train -- confirm intent.
                v.requires_grad = False
                print('frozen layer: %s ' % k)
            # Keys matching a `random_layer` pattern are not copied and keep
            # the fresh initialisation of `model`.
            if len(flag_random) == 0:
                if k in model_encoder_dict:
                    model_encoder_dict[k] = v
                    print(k)
                if k in model_domain_decoder_dict:
                    # NOTE(review): matching domain-decoder keys are printed
                    # but never copied into model_domain_decoder_dict. The
                    # load_state_dict call below therefore reloads the
                    # decoder's own current weights -- confirm whether
                    # `model_domain_decoder_dict[k] = v` is missing.
                    print(k)
        model.encoder.load_state_dict(model_encoder_dict)
        model.domain_decoder.load_state_dict(model_domain_decoder_dict)
        print('copy pretrain layer finish!')

    if args.iter_start != 0:
        # Resume all three networks from snapshots named
        # <snapshots_dir><iter>.pth / <iter>_D.pth / <iter>_D2.pth.
        # NOTE(review): this reads `args.snapshots_dir`, while checkpoints
        # below are written under `args.snapshot_dir` with 'CV_' / 'CVbest'
        # prefixes. Presumably snapshots_dir is a path prefix such as
        # '<snapshot_dir>/CV_' -- confirm it is defined by get_arguments.
        args.restore_from = args.snapshots_dir + str(args.iter_start) + '.pth'
        args.Dlabelrestore_from = args.snapshots_dir + str(
            args.iter_start) + '_D.pth'
        args.Dfeaturerestore_from = args.snapshots_dir + str(
            args.iter_start) + '_D2.pth'
        model.load_state_dict(torch.load(args.restore_from))
        model_label.load_state_dict(torch.load(args.Dlabelrestore_from))
        model_feature.load_state_dict(torch.load(args.Dfeaturerestore_from))
        print('load old model')
    """
        Setup optimization for training
    """
    # One Adam optimizer + StepLR(gamma=0.9) per network. The schedulers are
    # stepped once per iteration below, so step sizes count iterations.
    optimizer = optim.Adam(model.parameters(),
                           lr=args.learning_rate,
                           betas=(0.9, 0.99))

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=args.step_size,
                                                gamma=0.9)

    optimizer.zero_grad()

    optimizer_label = optim.Adam(model_label.parameters(),
                                 lr=args.learning_rate_Dl,
                                 betas=(0.9, 0.99))
    scheduler_label = torch.optim.lr_scheduler.StepLR(
        optimizer_label, step_size=args.step_size_Dl, gamma=0.9)

    optimizer_label.zero_grad()

    optimizer_feature = optim.Adam(model_feature.parameters(),
                                   lr=args.learning_rate_Df,
                                   betas=(0.9, 0.99))
    scheduler_feature = torch.optim.lr_scheduler.StepLR(
        optimizer_feature, step_size=args.step_size_Df, gamma=0.9)

    optimizer_feature.zero_grad()

    # BCEWithLogitsLoss applies the sigmoid itself, so this assumes
    # labelDiscriminator returns raw logits. CrossEntropyLoss is used with
    # per-sample integer domain labels for the feature discriminator.
    bce_loss = torch.nn.BCEWithLogitsLoss()
    entropy_loss = torch.nn.CrossEntropyLoss()
    rec_loss = MSELoss()

    # labels for adversarial training
    source_label = 1
    target_label = 0

    for i_iter in range(args.iter_start, args.num_steps):

        # Per-iteration loss values, used only for printing. They are reset
        # every iteration, so the printed numbers are not running averages.
        loss_seg_value = 0
        loss_adv_label_value = 0
        loss_adv_feature_value = 0

        loss_Dlabel_value = 0
        loss_Dfeature_value = 0
        loss_rec_value = 0
        optimizer.zero_grad()
        optimizer_label.zero_grad()
        optimizer_feature.zero_grad()

        # train G
        # Discriminator parameters get no gradients while the generator-side
        # losses are back-propagated.
        for param in model_label.parameters():
            param.requires_grad = False

        for param in model_feature.parameters():
            param.requires_grad = False

        # train with source
        _, batch = trainloader_iter.__next__()
        images, labels, _, _ = batch
        if usecuda:
            images_source = Variable(images).cuda(args.gpu)
        else:
            images_source = Variable(images)

        # `model` returns (reconstruction, features, segmentation prediction).
        recimg_source, feature_source, pred_source = model(images_source)

        # Supervised segmentation loss: only the source domain has labels here.
        loss_seg = loss_calc(pred_source, labels, args.gpu, usecuda)
        loss_seg_value += loss_seg.data.cpu().numpy()

        loss_rec_source = rec_loss(recimg_source, images_source)

        # Dice / Jaccard of the current source batch; only printed below.
        label = torch.argmax(pred_source, dim=1).float()
        sdice, sjac = dice_coeff(label.cpu(), labels.cpu())

        _, batch = targetloader_iter.__next__()
        # Target labels (`tlabels`) are loaded but not used for training.
        images, tlabels, _, _ = batch
        if usecuda:
            images_target = Variable(images).cuda(args.gpu)
        else:
            images_target = Variable(images)

        recimg_target, feature_target, pred_target = model(images_target)

        loss_rec_target = rec_loss(recimg_target, images_target)

        # Reconstruction (MSE) loss averaged over both domains.
        loss_rec = (loss_rec_source + loss_rec_target) / 2
        loss_rec_value += loss_rec.data.cpu().numpy()

        loss = loss_seg + args.lambda_rec * loss_rec
        loss.backward()

        # Target Domain Adv loss
        # acc the target domain adv loss
        # The target forward pass is re-run because loss.backward() above
        # (without retain_graph) released the graph of the first target
        # forward. The new feature_target / pred_target are also the tensors
        # the discriminators train on below.
        _, feature_target, pred_target = model(images_target)

        # Output-space adversarial loss: target predictions are labelled as
        # *source* so the segmenter learns to fool model_label.
        Dlabel_out = model_label(F.softmax(pred_target, dim=1))
        if usecuda:
            adv_source_label = Variable(
                torch.FloatTensor(
                    Dlabel_out.data.size()).fill_(source_label).cuda(args.gpu))
        else:
            adv_source_label = Variable(
                torch.FloatTensor(Dlabel_out.data.size()).fill_(source_label))
        loss_adv_label = bce_loss(Dlabel_out, adv_source_label)

        # Feature-space adversarial loss: one integer domain label per sample
        # (CrossEntropyLoss targets), again claiming "source".
        Dfeature_out = model_feature(feature_target)
        if usecuda:
            adv_source_label = Variable(
                torch.LongTensor(
                    Dfeature_out.size(0)).fill_(source_label).cuda(args.gpu))
        else:
            adv_source_label = Variable(
                torch.LongTensor(Dfeature_out.size(0)).fill_(source_label))
        loss_adv_feature = entropy_loss(Dfeature_out, adv_source_label)

        loss = args.lambda_adv_label * loss_adv_label + args.lambda_adv_feature * loss_adv_feature
        loss.backward()

        loss_adv_label_value += loss_adv_label.data.cpu().numpy()
        loss_adv_feature_value += loss_adv_feature.data.cpu().numpy()

        # train domain label classifier
        for param in model_label.parameters():
            param.requires_grad = True

        # source domain D loss
        # detach() stops discriminator losses from sending gradients back
        # into the segmentation network.
        pred_source = pred_source.detach()
        D_out = model_label(F.softmax(pred_source, dim=1))
        if usecuda:
            D_source_label = Variable(
                torch.FloatTensor(D_out.data.size()).fill_(source_label).cuda(
                    args.gpu))
        else:
            D_source_label = Variable(
                torch.FloatTensor(D_out.data.size()).fill_(source_label))
        loss_D = bce_loss(D_out, D_source_label)
        # Source and target halves are each weighted 1/2.
        loss_D = loss_D / 2
        loss_D.backward()
        loss_Dlabel_value += loss_D.data.cpu().numpy()

        # target domain D loss
        pred_target = pred_target.detach()
        D_out = model_label(F.softmax(pred_target, dim=1))
        if usecuda:
            D_target_label = Variable(
                torch.FloatTensor(D_out.data.size()).fill_(target_label).cuda(
                    args.gpu))
        else:
            D_target_label = Variable(
                torch.FloatTensor(D_out.data.size()).fill_(target_label))
        loss_D = bce_loss(D_out, D_target_label)
        loss_D = loss_D / 2
        loss_D.backward()
        loss_Dlabel_value += loss_D.data.cpu().numpy()

        # train domain feature classifier
        for param in model_feature.parameters():
            param.requires_grad = True

        # train with source
        feature_source = feature_source.detach()
        D_out = model_feature(feature_source)
        if usecuda:
            D_source_label = Variable(
                torch.LongTensor(D_out.size(0)).fill_(source_label).cuda(
                    args.gpu))
        else:
            D_source_label = Variable(
                torch.LongTensor(D_out.size(0)).fill_(source_label))
        loss_D = entropy_loss(D_out, D_source_label)
        loss_D = loss_D / 2
        loss_D.backward()
        loss_Dfeature_value += loss_D.data.cpu().numpy()

        # train with target
        feature_target = feature_target.detach()
        D_out = model_feature(feature_target)
        if usecuda:
            D_target_label = Variable(
                torch.LongTensor(D_out.size(0)).fill_(target_label).cuda(
                    args.gpu))
        else:
            D_target_label = Variable(
                torch.LongTensor(D_out.size(0)).fill_(target_label))
        loss_D = entropy_loss(D_out, D_target_label)
        loss_D = loss_D / 2
        loss_D.backward()
        loss_Dfeature_value += loss_D.data.cpu().numpy()

        # All gradients for this iteration have been accumulated; update all
        # three networks.
        optimizer.step()
        optimizer_label.step()
        optimizer_feature.step()

        # The schedulers are always constructed above, so these `is not None`
        # checks are always true.
        # NOTE(review): passing `epoch=` to StepLR.step() and calling get_lr()
        # are both deprecated in recent PyTorch (warnings are emitted).
        # get_last_lr() is the supported way to read the current LR -- confirm
        # the installed version.
        if scheduler is not None:
            scheduler.step(epoch=i_iter)
            args.learning_rate = scheduler.get_lr()[0]

        if scheduler_label is not None:
            scheduler_label.step(epoch=i_iter)
            args.learning_rate_Dl = scheduler_label.get_lr()[0]

        if scheduler_feature is not None:
            scheduler_feature.step(epoch=i_iter)
            args.learning_rate_Df = scheduler_feature.get_lr()[0]

        if (i_iter % 50 == 0):
            print('time = {0},lr = {1: 5f},lr_Dl = {2: 6f},lr_Df = {3: 6f}'.
                  format(datetime.datetime.now(), args.learning_rate,
                         args.learning_rate_Dl, args.learning_rate_Df))
            print(
                'iter = {0:8d}/{1:8d}, loss_seg = {2:.5f} loss_rec = {3:5f} loss_adv1 = {4:.5f}, loss_adv2 = {5:.5f}, loss_Dlabel = {6:.5f} loss_Dfeature = {7:.5f}'
                .format(i_iter, args.num_steps, loss_seg_value, loss_rec_value,
                        loss_adv_label_value, loss_adv_feature_value,
                        loss_Dlabel_value, loss_Dfeature_value))
            print('sdice2 = {0:.5f} sjac2 = {1:.5f}'.format(
                sdice,
                sjac,
            ))

        # Validate and checkpoint. Since 0 % n == 0, this also runs on the
        # very first iteration when iter_start is a multiple of
        # save_pred_every (e.g. 0).
        if i_iter % args.save_pred_every == 0:
            dice, jac = validate_model(model.get_target_segmentation_net(),
                                       valloader, './val/cvlab', i_iter,
                                       args.gpu, usecuda)
            print('val dice: %4f' % dice, 'val jac: %4f' % jac)
            # args.best_tjac holds the best validation Jaccard seen so far.
            # Improvements are saved as 'CVbest...' files, anything else as
            # 'CV_...' files.
            if jac > args.best_tjac:
                args.best_tjac = jac
                print('best val dice: %4f' % dice, 'best val jac: %f' % jac)
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir,
                             'CVbest' + str(i_iter) + '_' + str(jac) + '.pth'))
                torch.save(
                    model_label.state_dict(),
                    osp.join(
                        args.snapshot_dir,
                        'CVbest' + str(i_iter) + '_' + str(jac) + '_D.pth'))
                torch.save(
                    model_feature.state_dict(),
                    osp.join(
                        args.snapshot_dir,
                        'CVbest' + str(i_iter) + '_' + str(jac) + '_D2.pth'))

            else:
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir, 'CV_' + str(i_iter) + '.pth'))
                torch.save(
                    model_label.state_dict(),
                    osp.join(args.snapshot_dir,
                             'CV_' + str(i_iter) + '_D.pth'))
                torch.save(
                    model_feature.state_dict(),
                    osp.join(args.snapshot_dir,
                             'CV_' + str(i_iter) + '_D2.pth'))