Code Example #1
def train(args):

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset, config_file=args.config_file)
    loader = data_loader(data_path,
                         is_transform=True,
                         img_size=(args.img_rows, args.img_cols))
    n_classes = loader.n_classes

    # must use 1 worker on AWS SageMaker unless the container runs with ipc="host" or a larger shared memory size
    trainloader = data.DataLoader(loader,
                                  batch_size=args.batch_size,
                                  num_workers=1,
                                  shuffle=True)

    # Setup Model
    model = get_model(args.arch, n_classes)

    # Setup log dir / logging
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    configure(args.log_dir)

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model.cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.l_rate,
                                momentum=0.9,
                                weight_decay=5e-4)

    step = 0
    for epoch in range(args.n_epoch):
        start_time = time.time()
        for i, (images, labels) in enumerate(trainloader):
            images = Variable(images.cuda())
            labels = Variable(labels.cuda())

            optimizer.zero_grad()
            outputs = model(images)

            loss = cross_entropy2d(outputs, labels)

            loss.backward()
            optimizer.step()

            log_value('Loss', loss.data[0], step)
            step += 1

            if (i + 1) % 20 == 0:
                print("Epoch [%d/%d] Loss: %.4f" %
                      (epoch + 1, args.n_epoch, loss.data[0]),
                      flush=True)

        end_time = time.time()
        print('Epoch run time: %s' % (end_time - start_time))
        torch.save(
            model,
            os.path.join(args.log_dir, "{}_{}_{}_{}.pt".format(
                args.arch, args.dataset, args.feature_scale, epoch)))
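Every example in this collection calls `cross_entropy2d` without showing its definition. A minimal sketch of the pixel-wise 2D cross-entropy these pytorch-semseg-style projects typically use (an assumption inferred from the call sites, not verbatim upstream code): treat every pixel as one classification sample by flattening the (N, C, H, W) scores to (N*H*W, C).

import torch.nn.functional as F

def cross_entropy2d(input, target, weight=None, size_average=True):
    # input: (N, C, H, W) raw class scores; target: (N, H, W) integer labels.
    n, c, h, w = input.size()
    log_p = F.log_softmax(input, dim=1)
    # Flatten so each pixel becomes one sample for NLL.
    log_p = log_p.permute(0, 2, 3, 1).contiguous().view(-1, c)
    target = target.view(-1)
    return F.nll_loss(log_p, target, weight=weight,
                      reduction='mean' if size_average else 'sum')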
Code Example #2
File: sunet.py Project: wenliangdai/sunets
    def __init__(self,
                 num_classes,
                 pretrained=True,
                 ignore_index=-1,
                 weight=None,
                 output_stride='16'):
        super(d_sunet7128, self).__init__()
        self.num_classes = num_classes
        sunet = stackedunet7128(output_stride=output_stride)
        sunet = torch.nn.DataParallel(sunet,
                                      device_ids=range(
                                          torch.cuda.device_count())).cuda()
        if pretrained:
            checkpoint = torch.load(sunet7128_path)
            sunet.load_state_dict(checkpoint['state_dict'])

        self.features = sunet.module.features

        for n, m in self.features.named_modules():
            if 'bn' in n:
                m.momentum = mom_bn

        self.final = nn.Sequential(nn.Conv2d(2304, num_classes, kernel_size=1))

        self.mceloss = cross_entropy2d(ignore=ignore_index,
                                       size_average=False,
                                       weight=weight)
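Unlike Code Example #1, this project instantiates `cross_entropy2d` with `ignore`, `size_average`, and `weight` arguments and stores it as a module, so here the name denotes a loss class rather than a plain function. A minimal wrapper consistent with that constructor (a hypothetical reconstruction, not the sunets source):

import torch.nn as nn
import torch.nn.functional as F

class cross_entropy2d(nn.Module):
    # Hypothetical reconstruction of the loss module used above.
    def __init__(self, ignore=-1, size_average=True, weight=None):
        super().__init__()
        self.ignore = ignore
        self.reduction = 'mean' if size_average else 'sum'
        self.weight = weight

    def forward(self, input, target):
        n, c, h, w = input.size()
        log_p = F.log_softmax(input, dim=1)
        log_p = log_p.permute(0, 2, 3, 1).contiguous().view(-1, c)
        return F.nll_loss(log_p, target.view(-1), weight=self.weight,
                          ignore_index=self.ignore, reduction=self.reduction)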
Code Example #3
    def __init__(self,
                 num_classes,
                 pretrained=True,
                 use_aux=True,
                 ignore_index=-1,
                 output_stride='16'):
        super(d_resnet18, self).__init__()
        self.use_aux = use_aux
        self.num_classes = num_classes
        resnet = models.resnet18()
        if pretrained:
            resnet.load_state_dict(torch.load(res18_path))
        self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
                                    resnet.maxpool)
        self.layer1, self.layer2, self.layer3, self.layer4 = (
            resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4)

        d = dilation[output_stride]
        if d > 1:
            for n, m in self.layer3.named_modules():
                if '0.conv1' in n:
                    m.dilation, m.padding, m.stride = (1, 1), (1, 1), (1, 1)
                elif 'conv1' in n:
                    m.dilation, m.padding, m.stride = (d, d), (d, d), (1, 1)
                elif 'conv2' in n:
                    m.dilation, m.padding, m.stride = (d, d), (d, d), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
        for n, m in self.layer4.named_modules():
            if '0.conv1' in n:
                m.dilation, m.padding, m.stride = (d, d), (d, d), (1, 1)
            elif 'conv1' in n:
                m.dilation, m.padding, m.stride = (2 * d, 2 * d), (2 * d, 2 * d), (1, 1)
            elif 'conv2' in n:
                m.dilation, m.padding, m.stride = (2 * d, 2 * d), (2 * d, 2 * d), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)

        for n, m in chain(self.layer0.named_modules(),
                          self.layer1.named_modules(),
                          self.layer2.named_modules(),
                          self.layer3.named_modules(),
                          self.layer4.named_modules()):
            if 'downsample.1' in n or 'bn' in n:
                m.momentum = mom_bn

        self.final = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=mom_bn), nn.ReLU(inplace=True),
            nn.Dropout(0.1), nn.Conv2d(512, num_classes, kernel_size=1))

        self.mceloss = cross_entropy2d(ignore=ignore_index, size_average=False)
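Code Examples #2 and #3 both rely on module-level names that are not shown: the `dilation` lookup and the `mom_bn` BatchNorm momentum. From the arithmetic above (layer3 is only re-dilated when `d > 1`, while layer4 uses `d` and `2 * d`), a consistent mapping is `d = 1` for output stride '16' and `d = 2` for output stride '8'; the concrete values below are assumptions:

# Assumed module-level constants (not shown in the snippets):
dilation = {'16': 1, '8': 2}  # output stride -> base dilation rate
mom_bn = 0.05                 # hypothetical BatchNorm momentum override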
Code Example #4
def train(args):
    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    loader = data_loader(data_path,
                         is_transform=True,
                         img_size=(args.img_rows, args.img_cols))
    n_classes = loader.n_classes
    trainloader = data.DataLoader(loader,
                                  batch_size=args.batch_size,
                                  num_workers=4,
                                  shuffle=True)

    # Setup visdom for visualization
    vis = visdom.Visdom()

    loss_window = vis.line(X=torch.zeros((1, )).cpu(),
                           Y=torch.zeros((1)).cpu(),
                           opts=dict(xlabel='minibatches',
                                     ylabel='Loss',
                                     title='Training Loss',
                                     legend=['Loss']))

    # Setup Model
    model = get_model(args.arch, n_classes)

    if torch.cuda.is_available():
        model.cuda(0)
        test_image, test_segmap = loader[0]
        test_image = Variable(test_image.unsqueeze(0).cuda(0))
    else:
        test_image, test_segmap = loader[0]
        test_image = Variable(test_image.unsqueeze(0))

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.l_rate,
                                momentum=0.99,
                                weight_decay=5e-4)
    for epoch in range(args.n_epoch):
        for i, (images, labels) in enumerate(trainloader):
            if torch.cuda.is_available():
                images = Variable(images.cuda(0))
                labels = Variable(labels.cuda(0))
            else:
                images = Variable(images)
                labels = Variable(labels)

            # iter = len(trainloader) * epoch + i
            # poly_lr_scheduler(optimizer, args.l_rate, iter)

            optimizer.zero_grad()
            outputs = model(images)

            loss = cross_entropy2d(outputs, labels)

            loss.backward()
            optimizer.step()

            vis.line(X=torch.ones((1, 1)).cpu() * i,
                     Y=torch.Tensor([loss.data[0]]).unsqueeze(0).cpu(),
                     win=loss_window,
                     update='append')

            if (i + 1) % 20 == 0:
                print("Epoch [%d/%d] Loss: %.4f" %
                      (epoch + 1, args.n_epoch, loss.data[0]))

        # test_output = model(test_image)
        # predicted = loader.decode_segmap(test_output[0].cpu().data.numpy().argmax(0))
        # target = loader.decode_segmap(test_segmap.numpy())

        # vis.image(test_image[0].cpu().data.numpy(), opts=dict(title='Input' + str(epoch)))
        # vis.image(np.transpose(target, [2,0,1]), opts=dict(title='GT' + str(epoch)))
        # vis.image(np.transpose(predicted, [2,0,1]), opts=dict(title='Predicted' + str(epoch)))

        torch.save(
            model, "{}_{}_{}_{}.pkl".format(args.arch, args.dataset,
                                            args.feature_scale, epoch))
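This example (and several below) keeps a call to `poly_lr_scheduler` commented out. A sketch of the standard "poly" decay that name usually denotes, with the signature inferred from the commented call (the `max_iter` and `power` defaults are assumptions):

def poly_lr_scheduler(optimizer, init_lr, iter, max_iter=30000, power=0.9):
    # Poly decay: lr = init_lr * (1 - iter / max_iter) ** power.
    lr = init_lr * (1 - iter / max_iter) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr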
Code Example #5
def train(args, guotai_config):

    # Setup Dataloader
    print('###### Step One: Setup Dataloader')
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)

    # For 2D dataset keep is_transform True
    # loader = data_loader(data_path, is_transform=True, img_size=(args.img_rows, args.img_cols))

    # For 3D dataset keep is_transform False
    # loader = data_loader(data_path, is_transform=False, img_size=(args.img_rows, args.img_cols))
    if args.dataset == 'brats17_loader_guotai':
        config_data = guotai_config['data']
        # print(config_data)
        # config_net = config['network']
        config_train = guotai_config['training']
        random.seed(config_train.get('random_seed', 1))
        assert (config_data['with_ground_truth'])
        # net_type = config_net['net_type']
        # net_name = config_net['net_name']
        # class_num = config_net['class_num']
        # batch_size = config_data.get('batch_size', 5)
        dataloader_guotai = DataLoader(config_data)
        dataloader_guotai.load_data()
        loader = data_loader(dataloader_guotai)
    elif args.dataset == 'brats17_loader':
        loader = data_loader(data_path,
                             is_transform=False,
                             img_size=(args.img_rows, args.img_cols))
    else:
        loader = data_loader(data_path,
                             is_transform=True,
                             img_size=(args.img_rows, args.img_cols))

    n_classes = args.n_classes
    trainloader = data.DataLoader(loader,
                                  batch_size=args.batch_size,
                                  num_workers=4,
                                  shuffle=True)

    # Setup Model
    print('###### Step Two: Setup Model')
    model = get_model(args.arch, n_classes)
    if args.pretrained_path != 'empty':
        model = torch.load(args.pretrained_path)
    #model = torch.load('/home/donghao/Desktop/donghao/isbi2019/code/fast_segmentation_code/runs/bisenet3Dbrain_brats17_loader_1_251_3020_min.pkl')
    #model = torch.load('/home/donghao/Desktop/donghao/isbi2019/code/fast_segmentation_code/runs/2177/bisenet3Dbrain_brats17_loader_1_293_min.pkl')
    #model = torch.load('/home/donghao/Desktop/donghao/isbi2019/code/fast_segmentation_code/runs/9863/FCDenseNet57_brats17_loader_1_599.pkl')
    if torch.cuda.is_available():
        model.cuda(0)
        test_image, test_segmap = loader[0]
        test_image = Variable(test_image.unsqueeze(0).cuda(0))
    else:
        test_image, test_segmap = loader[0]
        test_image = Variable(test_image.unsqueeze(0))

    log('The optimizer is Adam')
    log('The learning rate is {}'.format(args.l_rate))
    # optimizer = torch.optim.Adam(model.parameters(), lr=1e-1)
    # optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.99)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.l_rate)
    # optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Train Model
    print('###### Step Three: Training Model')

    epoch_loss_array_total = np.zeros([1, 2])
    for epoch in range(args.n_epoch):
        img_counter = 1
        loss_sum = 0
        for i, (images, labels) in enumerate(trainloader):
            img_counter = img_counter + 1
            if torch.cuda.is_available():
                images = Variable(images.cuda(0))
                labels = Variable(labels.cuda(0))
            else:
                images = Variable(images)
                labels = Variable(labels)

            optimizer.zero_grad()
            # log('The maximum value of input image is {}'.format(images.max()))
            # print(images)
            outputs = model(images)
            # print(outputs)
            if args.arch in ('bisenet3Dbrain', 'unet3d_cls', 'FCDenseNet57', 'FCDenseNet103'):
                loss = cross_entropy3d(outputs, labels)
            elif args.arch == 'unet3d_res':
                labels = labels * 40
                labels = labels + 1
                log('The unique value of labels are {}'.format(
                    np.unique(labels)))
                log('The maximum of outputs are {}'.format(outputs.max()))
                log('The size of output is {}'.format(outputs.size()))
                log('The size of labels is {}'.format(labels.size()))
                criterion = nn.L1Loss()
                labels = labels.type(torch.cuda.FloatTensor)
                outputs = torch.squeeze(outputs, dim=1)

                loss = criterion(outputs, labels)
            else:
                loss = cross_entropy2d(outputs, labels)
            loss.backward()
            optimizer.step()
            loss_sum = loss_sum + torch.Tensor([loss.data]).unsqueeze(0).cpu()

        avg_loss = loss_sum / img_counter
        avg_loss_array = np.array(avg_loss)
        epoch_loss_array_total = np.concatenate(
            (epoch_loss_array_total, [[avg_loss_array[0][0], epoch]]), axis=0)
        print('The current loss of epoch', epoch, 'is', avg_loss_array[0][0])
        # training model will be saved
        log('The variable avg_loss_array is {}'.format(avg_loss_array))
        writer.add_scalar('train_main_loss', avg_loss_array[0][0], epoch)

        if epoch % 10 == 0:
            torch.save(
                model,
                "runs/{}/{}_{}_{}_{}.pkl".format(rand_int, args.arch,
                                                 args.dataset,
                                                 args.feature_scale, epoch))
    # epoch_loss_array_total has shape (n_epoch + 1, 2); the placeholder
    # first row is removed below
    log('epoch_loss_array_total is {}'.format(epoch_loss_array_total))
    log('the shape of epoch_loss_array_total is {}'.format(
        epoch_loss_array_total.shape))
    epoch_loss_array_total = np.delete(arr=epoch_loss_array_total,
                                       obj=0,
                                       axis=0)
    log('the shape of epoch_loss_array_total after removal is {}'.format(
        epoch_loss_array_total.shape))
    loss_min_indice = np.argmin(epoch_loss_array_total, axis=0)
    log('The loss_min_indice is {}'.format(loss_min_indice))
    torch.save(
        model,
        "runs/{}/{}_{}_{}_{}_min.pkl".format(rand_int, args.arch, args.dataset,
                                             args.feature_scale,
                                             loss_min_indice[0]))
    sys.stdout = orig_stdout
    f.close()
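For the volumetric architectures this example dispatches to a `cross_entropy3d`, which is also not defined here. Assuming it mirrors the 2D loss with one extra depth axis, a minimal sketch:

import torch.nn.functional as F

def cross_entropy3d(input, target, weight=None):
    # Hypothetical 3D analogue: input (N, C, D, H, W), target (N, D, H, W).
    n, c, d, h, w = input.size()
    log_p = F.log_softmax(input, dim=1)
    log_p = log_p.permute(0, 2, 3, 4, 1).contiguous().view(-1, c)
    return F.nll_loss(log_p, target.view(-1), weight=weight)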
Code Example #6
File: train_mangrove.py Project: flixpar/CoralSemSeg
def train(args):

    # Setup Dataloader
    data_loader = get_loader("mangrove")
    data_path = get_data_path("mangrove")
    loader = data_loader(data_path, img_size=args.img_size)

    n_classes = loader.n_classes
    n_channels = loader.n_channels

    trainloader = data.DataLoader(loader,
                                  batch_size=args.batch_size,
                                  num_workers=4,
                                  shuffle=True)

    # Setup visdom for visualization
    vis = visdom.Visdom()

    loss_window = vis.line(X=torch.zeros((1, )).cpu(),
                           Y=torch.zeros((1)).cpu(),
                           opts=dict(xlabel='minibatches',
                                     ylabel='Loss',
                                     title='Training Loss',
                                     legend=['Loss']))

    # Setup Model
    # default: coralnet
    model = get_model("coralnet", n_classes, in_channels=n_channels)

    if torch.cuda.is_available():
        model.cuda(0)
        test_image, test_segmap = loader[0]
        test_image = Variable(test_image.unsqueeze(0).cuda(0))
    else:
        print("CUDA Error.")
        test_image, test_segmap = loader[0]
        test_image = Variable(test_image.unsqueeze(0))

    # optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.99, weight_decay=5e-4)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    for epoch in range(args.n_epoch + 1):
        for i, (images, labels) in enumerate(trainloader):
            if torch.cuda.is_available():
                images = Variable(images.cuda(0))
                labels = Variable(labels.cuda(0))
            else:
                images = Variable(images)
                labels = Variable(labels)

            iter = len(trainloader) * epoch + i
            # poly_lr_scheduler(optimizer, args.learning_rate, iter)

            optimizer.zero_grad()
            outputs = model(images)

            loss = cross_entropy2d(outputs, labels)

            loss.backward()
            optimizer.step()

            vis.line(X=torch.ones((1, 1)).cpu() * i,
                     Y=torch.Tensor([loss.data[0]]).unsqueeze(0).cpu(),
                     win=loss_window,
                     update='append')

            if (i + 1) % 20 == 0:
                print("Epoch [%d/%d] Loss: %.4f" %
                      (epoch + 1, args.n_epoch, loss.data[0]))

        if epoch % 50 == 0:
            torch.save(
                model,
                "training/{}_{}_{}_{}.pkl".format("mangrove_multi", "linknet",
                                                  args.feature_scale,
                                                  epoch + 1))
Code Example #7
def train(args):

    # writer = SummaryWriter('exp')
    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    loader = data_loader(data_path,
                         is_transform=True,
                         img_size=(args.img_rows, args.img_cols))
    n_classes = loader.n_classes
    trainloader = data.DataLoader(loader,
                                  batch_size=args.batch_size,
                                  num_workers=4,
                                  shuffle=True)

    # Setup visdom for visualization
    # vis = visdom.Visdom()

    # loss_window = vis.line(X=torch.zeros((1,)).cpu(),
    #                        Y=torch.zeros((1)).cpu(),
    #                        opts=dict(xlabel='epoch',
    #                                  ylabel='Loss',
    #                                  title='Training Loss',
    #                                  legend=['Loss']))

    # Setup Model
    model = get_model(args.arch, n_classes)

    if torch.cuda.is_available():
        model.cuda(0)
        test_image, test_segmap = loader[0]
        test_image = Variable(test_image.unsqueeze(0).cuda(0))
    else:
        test_image, test_segmap = loader[0]
        test_image = Variable(test_image.unsqueeze(0))

    # optimizer = torch.optim.SGD(model.parameters(), lr=args.l_rate, momentum=0.99, weight_decay=5e-4)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.l_rate * 100)
    for epoch in range(args.n_epoch):
        img_counter = 1
        loss_sum = 0
        for i, (images, labels) in enumerate(trainloader):
            img_counter = img_counter + 1
            # print(img_counter)
            # print('i value: ', i)
            if torch.cuda.is_available():
                images = Variable(images.cuda(0))
                labels = Variable(labels.cuda(0))
            else:
                images = Variable(images)
                labels = Variable(labels)

            # iter = len(trainloader)*epoch + i
            # poly_lr_scheduler(optimizer, args.l_rate, iter)

            optimizer.zero_grad()
            outputs = model(images)
            loss = cross_entropy2d(outputs, labels)
            loss.backward()

            optimizer.step()

            # vis.line(
            #     X=torch.ones((1, 1)).cpu() * i,
            #     Y=torch.Tensor([loss.data[0]]).unsqueeze(0).cpu(),
            #     win=loss_window,
            #     update='append')

            if (i + 1) % 20 == 0:
                print("Epoch [%d/%d] Loss: %.4f" %
                      (epoch + 1, args.n_epoch, loss.data[0]))
            # print('The number of img_counter is ', img_counter)
            loss_sum = loss_sum + torch.Tensor([loss.data[0]]).unsqueeze(0).cpu()
        avg_loss = loss_sum / img_counter
        avg_loss_array = np.array(avg_loss)
        # vis.line(
        #     X=torch.ones((1, 1)).cpu() * epoch,
        #     Y=avg_loss,
        #     win=loss_window,
        #     update='append')
        # writer.add_scalar('train_main_loss', avg_loss, epoch)
        test_output = model(test_image)
        predicted = loader.decode_segmap(
            test_output[0].cpu().data.numpy().argmax(0))
        target = loader.decode_segmap(test_segmap.numpy())
        # if epoch == 1:
        #     vis.image(test_image[0].cpu().data.numpy(), opts=dict(title='Test Epoch' + str(epoch)))
        #     vis.image(np.transpose(target, [2,0,1]), opts=dict(title='GT Epoch' + str(epoch)))
        # vis.image(np.transpose(predicted, [2,0,1]), opts=dict(title='Predicted Epoch' + str(epoch)))

        torch.save(
            model, "{}_{}_{}_{}.pkl".format(args.arch, args.dataset,
                                            args.feature_scale, epoch))
Code Example #8
def train(args):
    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    loader = data_loader(data_path,
                         is_transform=True,
                         img_size=(args.img_rows, args.img_cols))
    n_classes = loader.n_classes
    trainloader = data.DataLoader(loader,
                                  batch_size=args.batch_size,
                                  num_workers=4,
                                  shuffle=True)

    # Setup Model
    net_features, net_segmenter = get_model(name=args.arch,
                                            n_classes=n_classes)
    netD = _netD()
    netG = _netG(nz=nz)
    padder = padder_layer(pad_size=100)

    criterion_gan = nn.BCELoss()

    input = torch.FloatTensor(args.batch_size_gan, 3, args.img_rows,
                              args.img_cols)
    noise = torch.FloatTensor(args.batch_size_gan, nz, 1, 1)
    fixed_noise = torch.FloatTensor(args.batch_size_gan, nz, 1,
                                    1).normal_(0, 1)
    label = torch.FloatTensor(args.batch_size_gan)

    real_label = .9
    fake_label = .1

    G = args.gpu

    if torch.cuda.is_available():
        input, label = input.cuda(G), label.cuda(G)
        noise, fixed_noise = noise.cuda(G), fixed_noise.cuda(G)

    fixed_noise = Variable(fixed_noise)

    if torch.cuda.is_available():
        net_features.cuda(G)
        net_segmenter.cuda(G)
        padder.cuda(G)
        netD.cuda(G)
        netG.cuda(G)
        criterion_gan.cuda(G)

    optimizerS = torch.optim.SGD(list(net_features.parameters()) +
                                 list(net_segmenter.parameters()),
                                 lr=args.l_rate,
                                 momentum=0.99,
                                 weight_decay=5e-4)
    optimizerD = torch.optim.Adam(list(net_features.parameters()) +
                                  list(netD.parameters()),
                                  lr=args.l_rate,
                                  betas=(args.beta1, 0.999))
    optimizerG = torch.optim.Adam(netG.parameters(),
                                  lr=args.l_rate,
                                  betas=(args.beta1, 0.999))

    SS = 1
    GAN = 0

    for epoch in range(args.n_epoch):
        for i, (images, labels) in enumerate(trainloader):

            real_cpu, label_cpu = images, labels  # 'data' here would be the torch.utils.data module
            batch_size = real_cpu.size(0)
            if torch.cuda.is_available():
                real_cpu = real_cpu.cuda(G)
                label_cpu = label_cpu.cuda(G)
            input.resize_as_(real_cpu).copy_(real_cpu)
            inputv = Variable(input)
            labelv_semantic = Variable(label_cpu)

            # iter = len(trainloader)*epoch + i
            # poly_lr_scheduler(optimizer, args.l_rate, iter)

            if GAN:
                ############################
                # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                ###########################
                optimizerD.zero_grad()
                netD.zero_grad()
                net_features.zero_grad()

                # train with real
                label.resize_(batch_size).fill_(real_label)
                labelv = Variable(label)
                inputv_feats = net_features(inputv)
                output = netD(inputv_feats)
                errD_real = criterion_gan(output, labelv)
                errD_real.backward()
                D_x = output.data.mean()

                # train with fake
                noise.resize_(batch_size, nz, 1, 1).normal_(0, 1)
                noisev = Variable(noise)
                fake = netG(noisev)
                labelv = Variable(label.fill_(fake_label))
                fake_feats = net_features(fake.detach())
                output = netD(fake_feats)
                errD_fake = criterion_gan(output, labelv)
                errD_fake.backward()
                D_G_z1 = output.data.mean()

                errD = errD_real + errD_fake
                optimizerD.step()

                ############################
                # (2) Update G network: maximize log(D(G(z)))
                ###########################
                optimizerG.zero_grad()
                netG.zero_grad()

                label.resize_(batch_size).fill_(real_label)
                labelv = Variable(label.fill_(
                    real_label))  # fake labels are real for generator cost
                noise.resize_(batch_size, nz, 1, 1).normal_(0, 1)
                noise2v = Variable(noise)
                fake2 = netG(noise2v)
                fake_feats2 = net_features(fake2)
                output = netD(fake_feats2)
                errG = criterion_gan(output, labelv)
                errG.backward()
                D_G_z2 = output.data.mean()
                optimizerG.step()

            if SS:
                optimizerS.zero_grad()
                net_features.zero_grad()
                net_segmenter.zero_grad()

                padded_images = padder(inputv)
                features = net_features(padded_images)
                outputs = net_segmenter(features)

                loss = cross_entropy2d(outputs, labelv_semantic)

                loss.backward()
                optimizerS.step()

                if (i + 1) % 20 == 0:
                    print("Iter [%d/%d], Epoch [%d/%d] Loss: %.4f" %
                          (i + 1, len(trainloader), epoch + 1, args.n_epoch,
                           loss.data[0]))

        # test_output = model(test_image)
        # predicted = loader.decode_segmap(test_output[0].cpu().data.numpy().argmax(0))
        # target = loader.decode_segmap(test_segmap.numpy())

        # vis.image(test_image[0].cpu().data.numpy(), opts=dict(title='Input' + str(epoch)))
        # vis.image(np.transpose(target, [2,0,1]), opts=dict(title='GT' + str(epoch)))
        # vis.image(np.transpose(predicted, [2,0,1]), opts=dict(title='Predicted' + str(epoch)))

        torch.save((net_features, net_segmenter),
                   "{}_{}_{}_{}.pkl".format(args.arch, args.dataset,
                                            args.feature_scale, epoch))
Code Example #9
File: fcn_dcgan.py Project: monaj07/pytorch_semseg
def train(args):
    # Setup Dataloader
    if args.dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        assert (args.img_cols == args.img_rows)
        dataset = dset.ImageFolder(root=args.dataroot,
                                   transform=transforms.Compose([
                                       transforms.Scale(args.img_cols),
                                       transforms.CenterCrop(args.img_cols),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5)),
                                   ]))
    elif args.dataset in ['pascal', 'camvid']:
        data_loader = get_loader(args.dataset)
        data_path = get_data_path(args.dataset)
        dataset = data_loader(data_path,
                              is_transform=True,
                              img_size=(args.img_rows, args.img_cols))

    n_classes = dataset.n_classes
    trainloader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              num_workers=4,
                                              shuffle=True)

    ############################
    # Setup the Models
    net_features, net_segmenter = get_model(name=args.arch,
                                            n_classes=n_classes)
    netD = _netD()
    netG = _netG(nz=nz)
    padder = padder_layer(pad_size=100)
    ############################

    ############################
    ### Initialization:
    if args.netD != '':
        netD.load_state_dict(torch.load(os.path.join(args.outf, args.netD)))
    if args.netG != '':
        netG.load_state_dict(torch.load(os.path.join(args.outf, args.netG)))
    if args.net_features != '':
        net_features.load_state_dict(
            torch.load(os.path.join(args.outf, args.net_features)))
    if args.net_segmenter != '':
        net_segmenter.load_state_dict(
            torch.load(os.path.join(args.outf, args.net_segmenter)))
    ############################

    criterion_gan = nn.BCELoss()

    input = torch.FloatTensor(args.batch_size_gan, 3, args.img_rows,
                              args.img_cols)
    noise = torch.FloatTensor(args.batch_size_gan, nz, 1, 1)
    fixed_noise = torch.FloatTensor(args.batch_size_gan, nz, 1,
                                    1).normal_(0, 1)
    label = torch.FloatTensor(args.batch_size_gan)

    real_label = .9
    fake_label = .1

    G = args.gpu

    if torch.cuda.is_available():
        input, label = input.cuda(G), label.cuda(G)
        noise, fixed_noise = noise.cuda(G), fixed_noise.cuda(G)

    fixed_noise = Variable(fixed_noise)

    if torch.cuda.is_available():
        net_features.cuda(G)
        net_segmenter.cuda(G)
        padder.cuda(G)
        netD.cuda(G)
        netG.cuda(G)
        criterion_gan.cuda(G)

    optimizerS = torch.optim.SGD(list(net_features.parameters()) +
                                 list(net_segmenter.parameters()),
                                 lr=args.l_rate,
                                 momentum=0.99,
                                 weight_decay=5e-4)
    optimizerD = torch.optim.Adam(list(net_features.parameters()) +
                                  list(netD.parameters()),
                                  lr=args.l_rate,
                                  betas=(args.beta1, 0.999))
    optimizerG = torch.optim.Adam(netG.parameters(),
                                  lr=args.l_rate,
                                  betas=(args.beta1, 0.999))

    SS = 1
    GAN = 0

    for epoch in range(args.n_epoch):
        for i, data in enumerate(trainloader):

            real_cpu, label_cpu = data
            batch_size = real_cpu.size(0)
            if torch.cuda.is_available():
                real_cpu = real_cpu.cuda(G)
                label_cpu = label_cpu.cuda(G)
            input.resize_as_(real_cpu).copy_(real_cpu)
            inputv = Variable(input)
            labelv_semantic = Variable(label_cpu)

            # iter = len(trainloader)*epoch + i
            # poly_lr_scheduler(optimizer, args.l_rate, iter)

            if GAN:
                ############################
                # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                ###########################
                optimizerD.zero_grad()
                net_features.zero_grad()
                netD.zero_grad()

                # train with real
                label.resize_(batch_size).fill_(real_label)
                labelv = Variable(label)
                inputv_feats = net_features(inputv)
                output = netD(inputv_feats)
                errD_real = criterion_gan(output, labelv)
                errD_real.backward()
                D_x = output.data.mean()

                # train with fake
                noise.resize_(batch_size, nz, 1, 1).normal_(0, 1)
                noisev = Variable(noise)
                fake = netG(noisev)
                labelv = Variable(label.fill_(fake_label))
                fake_feats = net_features(fake.detach())
                output = netD(fake_feats)
                errD_fake = criterion_gan(output, labelv)
                errD_fake.backward()
                D_G_z1 = output.data.mean()

                errD = errD_real + errD_fake
                optimizerD.step()

                ############################
                # (2) Update G network: maximize log(D(G(z)))
                ###########################
                optimizerG.zero_grad()
                netG.zero_grad()

                label.resize_(batch_size).fill_(real_label)
                labelv = Variable(label.fill_(
                    real_label))  # fake labels are real for generator cost
                noise.resize_(batch_size, nz, 1, 1).normal_(0, 1)
                noise2v = Variable(noise)
                fake2 = netG(noise2v)
                fake_feats2 = net_features(fake2)
                output = netD(fake_feats2)
                errG = criterion_gan(output, labelv)
                errG.backward()
                D_G_z2 = output.data.mean()
                optimizerG.step()

            if SS:
                optimizerS.zero_grad()
                net_features.zero_grad()
                net_segmenter.zero_grad()

                padded_images = padder(inputv)
                features = net_features(padded_images)
                outputs = net_segmenter(features)

                loss = cross_entropy2d(outputs, labelv_semantic)

                loss.backward()
                optimizerS.step()

            ######################################################
            ### Loss Report:
            ######################################################

            if SS:
                if (i + 1) % 20 == 0:
                    print("Iter [%d/%d], Epoch [%d/%d] Loss: %.4f" %
                          (i + 1, len(trainloader), epoch + 1, args.n_epoch,
                           loss.data[0]))

            if GAN:
                if i % 100 == 0 and i > 0:
                    print(
                        '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                        % (epoch, args.n_epoch, i, len(trainloader),
                           errD.data[0], errG.data[0], D_x, D_G_z1, D_G_z2))
                if i % 500 == 0 and i > 0:
                    vutils.save_image(real_cpu,
                                      '%s/real_samples.png' % args.outf,
                                      normalize=True)
                    fake = netG(fixed_noise)
                    vutils.save_image(fake.data,
                                      '%s/fake_samples_epoch_%03d_%03d.png' %
                                      (args.outf, epoch, i),
                                      normalize=True)
                if i % 3000 == 0 and i > 0:
                    torch.save(
                        net_features.state_dict(),
                        '%s/net_features_epoch_%d_%d.pth' %
                        (args.outf, epoch, i))
                    torch.save(
                        netG.state_dict(),
                        '%s/netG_epoch_%d_%d.pth' % (args.outf, epoch, i))
                    torch.save(
                        netD.state_dict(),
                        '%s/netD_epoch_%d_%d.pth' % (args.outf, epoch, i))

        ######################################################
        ### Do checkpointing:
        ######################################################
        torch.save((net_features, net_segmenter),
                   "{}_{}_{}_{}.pkl".format(args.arch, args.dataset,
                                            args.feature_scale, epoch))
        # NOTE: this 'continue' deliberately skips the per-network state_dict
        # checkpoints below; remove it to save them as well
        continue
        torch.save(net_features.state_dict(),
                   '%s/net_features_epoch_%d.pth' % (args.outf, epoch))
        torch.save(netG.state_dict(),
                   '%s/netG_epoch_%d.pth' % (args.outf, epoch))
        torch.save(netD.state_dict(),
                   '%s/netD_epoch_%d.pth' % (args.outf, epoch))
Code Example #10
File: train_pascal.py Project: weigq/pytorch-semseg
def train(model):

    if model == 'unet':
        model = unet(feature_scale=feature_scale,
                     n_classes=n_classes,
                     is_batchnorm=True,
                     in_channels=3,
                     is_deconv=True)

    if model == 'segnet':
        model = segnet(n_classes=n_classes, in_channels=3, is_unpooling=True)

    if model == 'fcn32':
        model = fcn32s(n_classes=n_classes)
        vgg16 = models.vgg16(pretrained=True)
        model.init_vgg16_params(vgg16)

    if model == 'fcn16':
        model = fcn16s(n_classes=n_classes)
        vgg16 = models.vgg16(pretrained=True)
        model.init_vgg16_params(vgg16)

    if model == 'fcn8':
        model = fcn8s(n_classes=n_classes)
        vgg16 = models.vgg16(pretrained=True)
        model.init_vgg16_params(vgg16)

    pascal = pascalVOCLoader(data_path, is_transform=True, img_size=img_rows)
    trainloader = data.DataLoader(pascal, batch_size=batch_size, num_workers=4)

    if torch.cuda.is_available():
        model.cuda(0)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=l_rate,
                                momentum=0.99,
                                weight_decay=5e-4)

    test_image, test_segmap = pascal[0]
    test_image = Variable(test_image.unsqueeze(0).cuda(0))
    vis = visdom.Visdom()

    for epoch in range(n_epoch):
        for i, (images, labels) in enumerate(trainloader):
            if torch.cuda.is_available():
                images = Variable(images.cuda(0))
                labels = Variable(labels.cuda(0))
            else:
                images = Variable(images)
                labels = Variable(labels)

            optimizer.zero_grad()
            outputs = model(images)

            loss = cross_entropy2d(outputs, labels)

            loss.backward()
            optimizer.step()

            if (i + 1) % 20 == 0:
                print("Epoch [%d/%d] Loss: %.4f" %
                      (epoch + 1, n_epoch, loss.data[0]))

        test_output = model(test_image)
        predicted = pascal.decode_segmap(
            test_output[0].cpu().data.numpy().argmax(0))
        target = pascal.decode_segmap(test_segmap.numpy())

        vis.image(test_image[0].cpu().data.numpy(),
                  opts=dict(title='Input' + str(epoch)))
        vis.image(np.transpose(target, [2, 0, 1]),
                  opts=dict(title='GT' + str(epoch)))
        vis.image(np.transpose(predicted, [2, 0, 1]),
                  opts=dict(title='Predicted' + str(epoch)))

    torch.save(model, "unet_voc_" + str(feature_scale) + ".pkl")
Code Example #11
File: train.py Project: monaj07/pytorch_semseg
def train(args):

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    loader = data_loader(data_path, split=args.split, is_transform=True, img_size=(args.img_rows, args.img_cols))
    n_classes = loader.n_classes
    trainloader = data.DataLoader(loader, batch_size=args.batch_size, num_workers=4, shuffle=True)


    # Setup Model
    pre_trained = True
    model = get_model(args.arch, n_classes, pre_trained=pre_trained)

    ############################################
    # Random weights initialization
    if not pre_trained:
        model.apply(weights_init)
    ############################################


    if args.restore_from != '':
        print('\n' + '-' * 40)
        model = torch.load(args.restore_from)
        print('Restored the trained network.')
        print('-' * 40)

    if torch.cuda.is_available():
        model.cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(), lr=args.l_rate, momentum=0.99, weight_decay=5e-4)

    for epoch in range(args.n_epoch):
        for i, (images, labels) in enumerate(trainloader):
            if torch.cuda.is_available():
                images = Variable(images.cuda(args.gpu))
                labels = Variable(labels.cuda(args.gpu))
            else:
                images = Variable(images)
                labels = Variable(labels)

            #iter = len(trainloader)*epoch + i
            adjust_learning_rate(optimizer, args.l_rate, epoch)

            optimizer.zero_grad()
            outputs = model(images)

            loss = cross_entropy2d(outputs, labels)

            loss.backward()
            optimizer.step()


            if (i+1) % 20 == 0:
                print("Iter [%d/%d], Epoch [%d/%d] Loss: %.4f" % (i+1, len(trainloader), epoch+1, args.n_epoch, loss.data[0]))

        # test_output = model(test_image)
        # predicted = loader.decode_segmap(test_output[0].cpu().data.numpy().argmax(0))
        # target = loader.decode_segmap(test_segmap.numpy())

        # vis.image(test_image[0].cpu().data.numpy(), opts=dict(title='Input' + str(epoch)))
        # vis.image(np.transpose(target, [2,0,1]), opts=dict(title='GT' + str(epoch)))
        # vis.image(np.transpose(predicted, [2,0,1]), opts=dict(title='Predicted' + str(epoch)))
        if args.restore_from != '':
            torch.save(model, "./{}/{}_{}_{}_from_{}.pkl".format(args.save_folder, args.arch, args.dataset, epoch, args.restore_from))
        else:
            torch.save(model, "./{}/{}_{}_{}.pkl".format(args.save_folder, args.arch, args.dataset, epoch))
Code Example #12
def train(args):
    ###################
    pre_trained = args.pre_trained # default='gt'
    ###################

    ############################################
    if pre_trained == 'gt':
        # When pre_trained = 'gt', i.e. when using supervised image-net weights, images were normalized with image-net mean.
        image_transform = None
    else:
        # When pre_trained = 'self' or 'no', i.e. in the self-supervised case, or unsupervised case, the input images are normalized this way:
        image_transform = transforms.Compose([transforms.ToTensor(),
                                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    loader = data_loader(data_path, split=args.split, is_transform=True,
                         img_size=(args.img_rows, args.img_cols), image_transform=image_transform)
    n_classes = loader.n_classes
    trainloader = data.DataLoader(loader, batch_size=args.batch_size, num_workers=4, shuffle=True)
    ############################################

    ############################################
    # Setup Model
    if args.netF_path == '' or args.netS_path == '':
        if pre_trained == 'gt':
            netF, netS = get_model(args.arch, n_classes, pre_trained=True)
            print('*' * 60)
            print('*' * 60)
            print('*' * 60)
            print('Training using a GT-supervised pre-trained model ...')
            print('*' * 60)
            print('*' * 60)
            print('*' * 60 + '\n')
        elif pre_trained == 'self':
            netF, netS = get_model(args.arch, n_classes, pre_trained=False)
            assert (args.model_path != '')
            # Loading a self-supervised trained model (ss_model)
            print('x' * 60)
            print('x' * 60)
            print('x' * 60)
            print('Training using a self-supervised pre-trained model ...')
            ss_model = torch.load(args.model_path)
            netF.init_alex_params(ss_model)
            #netS.init_alex_params(ss_model)
            netS.apply(weights_init)
            print('Restored the self-supervised pre_trained model {}'.format(args.model_path))
            print('x' * 60)
            print('x' * 60)
            print('x' * 60 + '\n')
        else: # pre_trained == 'no':
            netF, netS = get_model(args.arch, n_classes, pre_trained=False)
            # Random weight initialization
            netF.apply(weights_init)
            netS.apply(weights_init)
            print('o' * 60)
            print('o' * 60)
            print('o' * 60)
            print('Training without any pre-trained model ...')
            print('o' * 60)
            print('o' * 60)
            print('o' * 60 + '\n')

    padder = padder_layer(pad_size=100)
    ############################################

    ############################################
    # If resuming the training from a saved model
    if args.netF_path != '':
        print('\n' + '-' * 40)
        netF = torch.load(args.netF_path)
        print('Resuming the training netF from model in {}'.format(args.netF_path))
    if args.netS_path != '':
        netS = torch.load(args.netS_path)
        print('Resuming the training netS from model in {}'.format(args.netS_path))
        print('-' * 40)
    ############################################

    ############################################
    # Porting the networks to CUDA
    if torch.cuda.is_available():
        netF.cuda(args.gpu)
        netS.cuda(args.gpu)
        padder.cuda(args.gpu)
    ############################################

    ############################################
    # Defining the optimizer over the network parameters
    optimizerSS = torch.optim.SGD([{'params': netF.features.parameters()},
                                   {'params': netF.classifier.parameters(), 'lr': 10 * args.l_rate},
                                   {'params': netS.parameters(), 'lr': 20 * args.l_rate}],
                                  lr=args.l_rate, momentum=0.99, weight_decay=5e-4)
    optimizerSS_init = copy.deepcopy(optimizerSS)
    ############################################


    ############################################
    # TRAINING:
    for epoch in range(args.n_epoch):
        for i, (images, labels) in enumerate(trainloader):
            ######################
            # Porting the data to Autograd variables and CUDA (if available)
            if torch.cuda.is_available():
                images = Variable(images.cuda(args.gpu))
                labels = Variable(labels.cuda(args.gpu))
            else:
                images = Variable(images)
                labels = Variable(labels)

            ######################
            # Scheduling the learning rate
            if args.pre_trained=='no':
                adjust_learning_rate(optimizerSS, args.l_rate, epoch, step=20)
            else:
                adjust_learning_rate_v2(optimizerSS, optimizerSS_init, epoch, step=20)

            ######################
            # Setting the gradients to zero at each iteration
            optimizerSS.zero_grad()
            netF.zero_grad()
            netS.zero_grad()

            ######################
            # Passing the data through the networks
            padded_images = padder(images)
            feature_maps = netF(padded_images)
            score_maps = netS(feature_maps)
            outputs = F.upsample(score_maps, labels.size()[1:], mode='bilinear')

            ######################
            # Computing the loss and doing back-propagation
            loss = cross_entropy2d(outputs, labels)
            loss.backward()

            ######################
            # Updating the parameters
            optimizerSS.step()

            if (i+1) % 20 == 0:
                print("Iter [%d/%d], Epoch [%d/%d] Loss: %.4f" % (i+1, len(trainloader), epoch+1, args.n_epoch, loss.data[0]))

        # test_output = model(test_image)
        # predicted = loader.decode_segmap(test_output[0].cpu().data.numpy().argmax(0))
        # target = loader.decode_segmap(test_segmap.numpy())

        # vis.image(test_image[0].cpu().data.numpy(), opts=dict(title='Input' + str(epoch)))
        # vis.image(np.transpose(target, [2,0,1]), opts=dict(title='GT' + str(epoch)))
        # vis.image(np.transpose(predicted, [2,0,1]), opts=dict(title='Predicted' + str(epoch)))

        torch.save(netF, "./{}/netF_{}_{}_{}.pkl".format(args.save_folder, args.arch, args.dataset, epoch))
        torch.save(netS, "./{}/netS_{}_{}_{}.pkl".format(args.save_folder, args.arch, args.dataset, epoch))
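This example and Code Examples #11 and #14 call `adjust_learning_rate` helpers (here with a `_v2` variant that rescales each parameter group from a saved copy of the initial optimizer; #14 uses yet another signature). A plausible step-decay version matching the `(optimizer, base_lr, epoch, step)` call, offered as an assumption rather than the projects' actual helper:

def adjust_learning_rate(optimizer, base_lr, epoch, step=20):
    # Divide the learning rate by 10 every `step` epochs.
    lr = base_lr * (0.1 ** (epoch // step))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr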
Code Example #13
def train(args):

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    loader = data_loader(data_path,
                         is_transform=True,
                         img_size=(args.img_rows, args.img_cols))
    n_classes = loader.n_classes
    trainloader = data.DataLoader(loader,
                                  batch_size=args.batch_size,
                                  num_workers=4,
                                  shuffle=True)

    # Setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom()

        loss_window = vis.line(X=torch.zeros((1, )).cpu(),
                               Y=torch.zeros((1)).cpu(),
                               opts=dict(xlabel='minibatches',
                                         ylabel='Loss',
                                         title='Training Loss',
                                         legend=['Loss']))

    # Setup Model
    model = get_model(args.arch, n_classes)

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model.cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.l_rate,
                                momentum=0.99,
                                weight_decay=5e-4)

    for epoch in range(args.n_epoch):
        for i, (images, labels) in enumerate(trainloader):
            images = Variable(images.cuda())
            labels = Variable(labels.cuda())

            optimizer.zero_grad()
            outputs = model(images)

            loss = cross_entropy2d(outputs, labels)

            loss.backward()
            optimizer.step()

            if args.visdom:
                vis.line(X=torch.ones((1, 1)).cpu() * i,
                         Y=torch.Tensor([loss.data[0]]).unsqueeze(0).cpu(),
                         win=loss_window,
                         update='append')

            if (i + 1) % 20 == 0:
                print("%s / Epoch [%d/%d] Loss: %.4f" %
                      (args.arch, epoch + 1, args.n_epoch, loss.data[0]))

        # torch.save(model, "{}_{}_{}_{}.pkl".format(args.arch, args.dataset, args.feature_scale, epoch))
        torch.save(
            model, "{}_{}_{}.pkl".format(args.arch, args.dataset,
                                         args.feature_scale))
Code Example #14
def train(args, model, optimizer, dataset, episode=0):
    
    trainloader = data.DataLoader(dataset, batch_size=args.batch_size, num_workers=8, shuffle=True, drop_last=True)

    class_weight = Variable(dataset.class_weight.cuda())

    lr = args.l_rate
    n_epoch = args.n_epoch
    optimizer.param_groups[0]['lr'] = args.l_rate
    model.train()

    # Setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom(port=args.visdom)
        loss_window = vis.line(X=np.column_stack((np.zeros((1,)))),
                               Y=np.column_stack((np.zeros((1)))),
                               opts=dict(xlabel='epoch',
                                         ylabel='Loss',
                                         title=args.mode + '_' + args.exp_name + '_Episode_' + str(episode),
                                         legend=['Train Loss']))

    t1 = time.time()
    start_epoch = args.start_epoch if episode == args.start_episode else 0
    best_iou = -100.0
    save_interval = int(floor(n_epoch*args.save_percent))
    for epoch in range(1 + start_epoch, n_epoch + 1):
        utils.adjust_learning_rate(optimizer, args.l_rate, args.lr_decay, 
                                     epoch - 1, 1)
        for i, (images, labels, image_name) in enumerate(trainloader):
            
            images = Variable(images.cuda())
            labels = Variable(labels.cuda(non_blocking=True))  # 'async' is reserved in Python 3.7+
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = cross_entropy2d(outputs, labels, class_weight)
            loss.backward()
            optimizer.step()
        if epoch % (save_interval*args.eval_interval) == 0:
            gts, preds, uncts = test(args, model=model, split='val')
            model.train()
            _, score = eval_metrics(args, gts, preds, verbose=False)
            print('val Mean IoU: ', score['Mean IoU : \t'])
            if score['Mean IoU : \t'] >= best_iou:
                best_iou = score['Mean IoU : \t']
                state = {'episode': episode,
                         'epoch': epoch,
                         'model_state': model.state_dict(),
                         'optimizer_state': optimizer.state_dict()}
                print "update best model {}".format(best_iou)
                torch.save(state, "checkpoint/{}/{}_{}_{}_best_model.pkl".format(\
                                        args.exp_name, args.arch, 'camvid', episode))           
        
        utils.adjust_learning_rate(optimizer, args.l_rate, args.lr_decay, 
                                     epoch - 1, 1)
        
        if epoch % save_interval == 0:
            print('data_size : ', len(dataset))
            state = {
                'episode' : episode,
                'epoch': epoch,
                'arch': args.arch,
                'loss': loss.data[0],
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, 'checkpoint/{}/{}_{}_{}.pth.tar'.format(\
                                   args.exp_name, args.arch, episode, epoch))
            print("Epoch [%d/%d] Loss: %.4f  lr:%.4f" %
                    (epoch, n_epoch, loss.data[0], optimizer.param_groups[0]['lr'] )) 
            t2 = time.time()
            print(save_interval, 'epoch time :', t2 - t1)
            t1 = time.time()

        if args.visdom:
            vis.line(
                X=np.column_stack((np.ones((1,)) * epoch)),
                Y=np.column_stack((np.array([loss.data[0]]))),
                win=loss_window,
                update='append')
    return model, optimizer
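A closing note on API vintage: all of these snippets target pre-0.4 PyTorch, hence the `Variable` wrappers, `loss.data[0]` scalar reads, `transforms.Scale`, and `F.upsample`. On PyTorch 0.4 and later the same inner loop reduces to the sketch below (an equivalent modern form, not code from any of the projects above):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

for epoch in range(n_epoch):
    for i, (images, labels) in enumerate(trainloader):
        # Tensors carry autograd state directly; no Variable wrapper needed.
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = cross_entropy2d(model(images), labels)
        loss.backward()
        optimizer.step()
        if (i + 1) % 20 == 0:
            print("Epoch [%d/%d] Loss: %.4f" % (epoch + 1, n_epoch, loss.item()))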