예제 #1
0
def test(model, args):
    """Run edge inference on every image under ``../dataset/test`` and save the maps.

    For each image, the sigmoid of the model's last output (the fused map) is
    scaled to [0, 255] and written with ``cv2.imwrite`` to
    ``<args.save_dir>/result_<args.train_dataset>/<image path>``.

    Args:
        model: network whose forward returns a sequence of logit maps; only the
            last entry is used.
        args: namespace providing ``save_dir``, ``train_dataset`` and ``cuda``.
            CUDA is used only when ``args.cuda`` and the module-level
            ``IS_CUDA`` flag are both true.
    """
    root = '../dataset/test'
    test_img = Data(root)
    test_loader = torch.utils.data.DataLoader(test_img,
                                              batch_size=1,
                                              shuffle=False)
    save_dir = args.save_dir
    train_dataset = args.train_dataset
    result_dir = os.path.join(save_dir, 'result_%s' % train_dataset)
    # Create the output tree once (including missing parents) instead of
    # checking for it on every image.
    os.makedirs(result_dir, exist_ok=True)
    use_cuda = args.cuda and IS_CUDA
    if use_cuda:
        model.cuda()
    model.eval()
    start_time = time.time()
    all_t = 0

    # Pure inference: no_grad stops autograd from recording a graph per image.
    with torch.no_grad():
        for data in test_loader:
            # The loader yields (paths, images). Unpack first: only the image
            # tensor can be moved to the GPU, not the collated batch sequence.
            x_path, x = data
            if use_cuda:
                x = x.cuda()
            tm = time.time()
            out = model(x.float())
            fuse = torch.sigmoid(out[-1]).cpu().numpy()[0, 0, :, :]
            # batch_size=1, so the collated path batch holds exactly one name.
            name = x_path[0]
            cv2.imwrite(os.path.join(result_dir, '%s' % name), fuse * 255)
            print(name)
            all_t += time.time() - tm
    print('Overall Time use: ', time.time() - start_time)
예제 #2
0
def test(model, args):
    """Run single-scale inference on ``args.dataset``'s test split and save PNGs.

    Output file names come from the first column of the name list
    (``test.lst``, or ``test<k>_id.txt`` for Multicue) with the extension
    stripped. Each map is the sigmoid of the model's last output, scaled to
    [0, 255], converted to 8-bit grayscale and saved to ``<args.res_dir>/fuse/``.
    Images that cannot be written are reported and skipped.

    Args:
        model: network whose forward returns a sequence of logit maps; only the
            last entry is used.
        args: namespace providing ``dataset``, ``k`` (Multicue split),
            ``res_dir`` and ``cuda``.
    """
    test_root = cfg.config_test[args.dataset]['data_root']
    test_lst = cfg.config_test[args.dataset]['data_lst']
    test_name_lst = os.path.join(test_root, 'test.lst')
    if 'Multicue' in args.dataset:
        test_lst = test_lst % args.k
        test_name_lst = os.path.join(test_root, 'test%d_id.txt' % args.k)
    mean_bgr = np.array(cfg.config_test[args.dataset]['mean_bgr'])
    test_img = Data(test_root, test_lst, mean_bgr=mean_bgr)
    testloader = torch.utils.data.DataLoader(test_img,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=0)
    # ndmin=2 keeps one row per image even for single-column or single-line
    # files, so nm[i][0] is always the first column, never one character of it.
    nm = np.loadtxt(test_name_lst, dtype=str, ndmin=2)
    print(len(testloader), len(nm))
    assert len(testloader) == len(nm), 'name list and test loader lengths differ'
    fuse_dir = os.path.join(args.res_dir, 'fuse')
    os.makedirs(fuse_dir, exist_ok=True)
    if args.cuda:
        model.cuda()
    model.eval()
    start_time = time.time()
    all_t = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(testloader):
            print("index: ", i)
            if args.cuda:
                data = data.cuda()
            tm = time.time()
            out = model(data)
            fuse = torch.sigmoid(out[-1]).cpu().numpy()[0, 0, :, :]
            # splitext removes the real extension (of any length, or none),
            # unlike slicing off a fixed four characters.
            stem = os.path.splitext(nm[i][0])[0]
            try:
                pic = Image.fromarray(fuse * 255).convert('L')
                pic.save(os.path.join(fuse_dir, '%s.png' % stem), "PNG")
            except (OSError, ValueError) as e:
                # Best effort: report the failure and continue with the next image.
                print(e)
                print("not write", i)
            all_t += time.time() - tm

    print(all_t)
    print('Overall Time use: ', time.time() - start_time)
예제 #3
0
파일: test_ms.py 프로젝트: xavysp/BDCN_xsp
def test(model, args,running_on='cpu'):
    test_root = cfg.config_test[args.test_data]['data_root']
    test_lst = cfg.config_test[args.test_data]['data_lst']
    test_name_lst = os.path.join(test_root, 'voc_valtest.txt')
    # if 'Multicue' in args.dataset:
    #     test_lst = test_lst % args.k
    #     test_name_lst = os.path.join(test_root, 'test%d_id.txt'%args.k)
    mean_bgr = np.array(cfg.config_test[args.test_data]['mean_bgr'])
    test_img = Data(test_root, test_lst, mean_bgr=mean_bgr, is_train=False, dataset_name=args.test_data,
                    scale=[0.5, 1, 1.5])
    testloader = torch.utils.data.DataLoader(
        test_img, batch_size=1, shuffle=False, num_workers=8)
    assert len(test_img ) >0

    base_dir = args.res_dir
    dataset_save_dir = os.path.join('edges', args.model_name + '_' + args.train_data + str(2) + args.test_data)

    save_dir = os.path.join(base_dir, dataset_save_dir)
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    if args.cuda:
        model.cuda()
    model.eval()
    data_iter = iter(testloader)
    iter_per_epoch = len(testloader)
    start_time = time.time()
    all_t = []

    with torch.no_grad():

        for i, (ms_data, label) in enumerate(testloader):
            ms_fuse = np.zeros((label.size()[2], label.size()[3]))
            tm = time.time()
            for data in ms_data:
                if args.cuda:
                    data = data.cuda()
                data = Variable(data, volatile=True)
                out = model(data)
                fuse = torch.sigmoid(out[-1]).cpu().data.numpy()[0, 0, :, :]
                fuse = cv2.resize(fuse, (label.size()[3], label.size()[2]), interpolation=cv2.INTER_LINEAR)
                ms_fuse += fuse
            ms_fuse /= len(ms_data)
            # tm = time.time()
            if not os.path.exists(os.path.join(save_dir, 'ms_pred')):
                os.mkdir(os.path.join(save_dir, 'ms_pred'))
            name = testloader.dataset.images_name[i]
            cv2.imwrite(os.path.join(save_dir, 'ms_pred', '%s.png'%name), np.uint8(255-ms_fuse*255))
            all_t.append(time.time() - tm)
            print('Done: ',name,'in ', i+1)
    all_t = np.array(all_t)
    print('Average time per image: ', all_t.mean())
    print ('Overall Time use: ', time.time() - start_time)
예제 #4
0
def test(model, args):
    """Run single-scale inference on ``args.dataset`` and save fused maps as PNG.

    Outputs go to ``<args.res_dir>/<checkpoint stem>_fuse/fuse/<image stem>.png``.
    The checkpoint stem comes from ``args.model``. Each image stem is the base
    name, without extension, of the first column of ``test.lst``. Each map is
    the sigmoid of the model's last output scaled to [0, 255].

    Args:
        model: network whose forward returns a sequence of logit maps; only the
            last entry is used.
        args: namespace providing ``dataset``, ``model`` (checkpoint path),
            ``res_dir`` and ``cuda``.
    """
    test_root = cfg.config_test[args.dataset]['data_root']
    test_lst = cfg.config_test[args.dataset]['data_lst']
    test_name_lst = os.path.join(test_root, 'test.lst')
    mean_bgr = np.array(cfg.config_test[args.dataset]['mean_bgr'])

    test_img = Data(test_root, test_lst, mean_bgr=mean_bgr)
    testloader = torch.utils.data.DataLoader(test_img,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=1)
    # ndmin=2: one row per image even for single-column/single-line lists.
    nm = np.loadtxt(test_name_lst, dtype=str, ndmin=2)

    save_dir = os.path.join(args.res_dir,
                            args.model.split('/')[-1].split('.')[0] + '_fuse')
    fuse_dir = os.path.join(save_dir, 'fuse')
    os.makedirs(fuse_dir, exist_ok=True)

    if args.cuda:
        model.cuda()

    model.eval()
    start_time = time.time()
    all_t = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(testloader):
            if args.cuda:
                data = data.cuda()
            tm = time.time()
            out = model(data)
            fuse = torch.sigmoid(out[-1]).cpu().numpy()[0, 0, :, :]
            fuse = Image.fromarray(fuse * 255).convert('RGB')
            # basename/splitext works at any directory depth; indexing the
            # third '/'-separated component raised IndexError on shorter paths.
            stem = os.path.splitext(os.path.basename(nm[i][0]))[0]
            fuse.save(os.path.join(fuse_dir, '{}.png'.format(stem)))
            all_t += time.time() - tm
    print('Save prediction into folder {}'.format(str(fuse_dir)))
    print('Overall Time use: ', time.time() - start_time)
예제 #5
0
def test(model, args):
    """Multi-scale (0.5x, 1x, 1.5x) inference on ``args.dataset``; saves JPEGs.

    For every image, the sigmoid of the model's last output is computed at each
    scale, resized to the label's height/width, and averaged. The average is
    inverted (``255 - p * 255``) and written to
    ``<args.res_dir>/fuse/<stem>.jpg``. The stem is the base name, without
    extension, of the first column of ``test_pair2.lst`` (or
    ``test<k>_id.txt`` for Multicue).

    Args:
        model: network whose forward returns a sequence of logit maps; only the
            last entry is used.
        args: namespace providing ``dataset``, ``k``, ``res_dir`` and ``cuda``.
    """
    test_root = cfg.config_test[args.dataset]['data_root']
    test_lst = cfg.config_test[args.dataset]['data_lst']
    test_name_lst = os.path.join(test_root, 'test_pair2.lst')
    if 'Multicue' in args.dataset:
        test_lst = test_lst % args.k
        test_name_lst = os.path.join(test_root, 'test%d_id.txt' % args.k)
    mean_bgr = np.array(cfg.config_test[args.dataset]['mean_bgr'])
    test_img = Data(test_root,
                    test_lst,
                    mean_bgr=mean_bgr,
                    scale=[0.5, 1, 1.5])
    testloader = torch.utils.data.DataLoader(test_img,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=8)
    # ndmin=2 gives one row per image; nm[i][0] is the first column whether the
    # list has one column (ids) or several (image/label pairs).
    nm = np.loadtxt(test_name_lst, dtype=str, ndmin=2)
    assert len(testloader) == len(nm)

    fuse_dir = os.path.join(args.res_dir, 'fuse')
    os.makedirs(fuse_dir, exist_ok=True)
    if args.cuda:
        model.cuda()
    model.eval()
    start_time = time.time()
    # volatile=True is ignored by modern PyTorch; no_grad is what disables
    # graph construction for the three forward passes per image.
    with torch.no_grad():
        for i, (ms_data, label) in enumerate(testloader):
            height, width = label.size()[2], label.size()[3]
            ms_fuse = np.zeros((height, width))
            for data in ms_data:
                if args.cuda:
                    data = data.cuda()
                out = model(data)
                fuse = torch.sigmoid(out[-1]).cpu().numpy()[0, 0, :, :]
                # cv2.resize takes (width, height).
                ms_fuse += cv2.resize(fuse, (width, height),
                                      interpolation=cv2.INTER_LINEAR)
            ms_fuse /= len(ms_data)
            # Format a single name string; '%s' % nm[i] would render a whole
            # numpy row (e.g. "['a.jpg' 'a.png']") for multi-column lists.
            stem = os.path.splitext(os.path.basename(nm[i][0]))[0]
            cv2.imwrite(os.path.join(fuse_dir, '%s.jpg' % stem),
                        255 - ms_fuse * 255)
    print('Overall Time use: ', time.time() - start_time)
예제 #6
0
def test(model, args):
    """Run single-scale inference on images under a fixed local folder; save JPEG-less PNGs.

    Images are read from a hard-coded ``test_root``. The list file comes from
    ``args.dataset``'s config. Output names come from ``voc_valtest.txt`` (or
    ``test<k>_id.txt`` for Multicue) in that folder. Each map is the inverted
    sigmoid of the model's last output (``255 - p * 255``), written to
    ``<args.res_dir>/fuse/<stem>.png``. Total inference time is printed at the end.

    Args:
        model: network whose forward returns a sequence of logit maps; only the
            last entry is used.
        args: namespace providing ``dataset``, ``k``, ``res_dir`` and ``cuda``.
    """
    # NOTE(review): machine-specific absolute path; the configured value would be
    # cfg.config_test[args.dataset]['data_root'].
    test_root = '/home/liu/桌面/pic'
    test_lst = cfg.config_test[args.dataset]['data_lst']
    test_name_lst = os.path.join(test_root, 'voc_valtest.txt')
    if 'Multicue' in args.dataset:
        test_lst = test_lst % args.k
        test_name_lst = os.path.join(test_root, 'test%d_id.txt' % args.k)
    mean_bgr = np.array(cfg.config_test[args.dataset]['mean_bgr'])
    test_img = Data(test_root, test_lst, mean_bgr=mean_bgr)
    testloader = torch.utils.data.DataLoader(test_img,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=8)
    # ndmin=2 gives one row per image, so nm[i][0] is always a single name.
    nm = np.loadtxt(test_name_lst, dtype=str, ndmin=2)
    print(len(testloader), len(nm))
    assert len(testloader) == len(nm)
    fuse_dir = os.path.join(args.res_dir, 'fuse')
    os.makedirs(fuse_dir, exist_ok=True)
    if args.cuda:
        model.cuda()
    model.eval()
    start_time = time.time()
    all_t = 0
    # volatile=True has no effect in modern PyTorch; no_grad disables autograd.
    with torch.no_grad():
        for i, (data, _) in enumerate(testloader):
            if args.cuda:
                data = data.cuda()
            tm = time.time()
            out = model(data)
            fuse = torch.sigmoid(out[-1]).cpu().numpy()[0, 0, :, :]
            stem = os.path.splitext(os.path.basename(nm[i][0]))[0]
            cv2.imwrite(os.path.join(fuse_dir, '%s.png' % stem),
                        255 - fuse * 255)
            all_t += time.time() - tm
    print(all_t)
    print('Overall Time use: ', time.time() - start_time)
예제 #7
0
def test(model, args):
    """Run single-scale inference on ``args.dataset`` and save inverted JPEG maps.

    Image names are taken from the first column of the dataset's list file.
    Each output is named after that entry's base name without extension and
    written to ``<args.res_dir>/fuse/<name>.jpg``. Its content is
    ``255 - p * 255``, where ``p`` is the sigmoid of the model's last output.

    Args:
        model: network whose forward returns a sequence of logit maps; only the
            last entry is used.
        args: namespace providing ``dataset``, ``k`` (Multicue split),
            ``res_dir`` and ``cuda``.
    """
    test_root = cfg.config_test[args.dataset]['data_root']
    test_lst = cfg.config_test[args.dataset]['data_lst']

    if 'Multicue' in args.dataset:
        test_lst = test_lst % args.k
    # Build the name-list path only after substituting the Multicue split
    # number; otherwise loadtxt is pointed at a path containing a literal '%d'.
    test_name_lst = os.path.join(test_root, test_lst)

    mean_bgr = np.array(cfg.config_test[args.dataset]['mean_bgr'])

    test_img = Data(test_root, test_lst, 0.5, mean_bgr=mean_bgr)
    testloader = torch.utils.data.DataLoader(test_img,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=8)
    # ndmin=2 so that [:, 0] also works for single-line or single-column lists.
    lst = np.loadtxt(test_name_lst, dtype=str, ndmin=2)[:, 0]
    nm = [os.path.splitext(os.path.basename(x))[0] for x in lst]
    fuse_dir = os.path.join(args.res_dir, 'fuse')
    os.makedirs(fuse_dir, exist_ok=True)
    if args.cuda:
        model.cuda()
    model.eval()
    start_time = time.time()
    all_t = 0
    # volatile=True has no effect in modern PyTorch; no_grad disables autograd.
    with torch.no_grad():
        for i, (data, _) in enumerate(testloader):
            if args.cuda:
                data = data.cuda()
            t1 = time.time()
            out = model(data)
            t = torch.sigmoid(out[-1]).cpu().numpy()[0, 0, :, :]
            cv2.imwrite(os.path.join(fuse_dir, '%s.jpg' % nm[i]),
                        255 - t * 255)
            all_t += time.time() - t1

    print(all_t)
    print('Overall Time use: ', time.time() - start_time)
예제 #8
0
def _bdcn_param_groups(model, base_lr, weight_decay):
    """Build SGD parameter groups with BDCN's per-layer learning-rate multipliers.

    Each parameter name is tested with ``re.match`` (anchored at the start of
    the name) against ``rules`` in order. The first matching rule supplies the
    multipliers; names matching no rule use ``fallback``. Weights get
    ``weight_decay`` scaled by the rule's decay multiplier, and biases are
    never weight-decayed. Parameters whose name contains neither 'weight' nor
    'bias' are not included in the returned groups.

    Returns:
        list of dicts with 'params', 'lr', 'weight_decay' and 'name' keys.
    """
    # (name pattern, weight lr mult, bias lr mult, weight decay mult)
    rules = (
        (r'conv[1-5]_[1-3]_down', 0.1, 0.2, 1),
        (r'.*conv[1-4]_[1-3]', 1, 2, 1),
        (r'.*conv5_[1-3]', 100, 200, 1),
        (r'score_dsn[1-5]', 0.01, 0.02, 1),
        # Upsampling layers get a zero learning rate and no weight decay.
        (r'upsample_[248](_5)?', 0, 0, 0),
        (r'.*msblock[1-5]_[1-3]\.conv', 1, 2, 1),
    )
    fallback = (0.001, 0.002, 1)
    params = []
    for key, v in model.named_parameters():
        w_lr, b_lr, w_wd = next(
            (rule[1:] for rule in rules if re.match(rule[0], key)), fallback)
        if 'weight' in key:
            params.append({
                'params': v,
                'lr': base_lr * w_lr,
                'weight_decay': weight_decay * w_wd,
                'name': key
            })
        elif 'bias' in key:
            params.append({
                'params': v,
                'lr': base_lr * b_lr,
                'weight_decay': weight_decay * 0,
                'name': key
            })
    return params


def train(model, args):
    """Train BDCN with SGD for ``args.max_iter`` iterations.

    Each iteration accumulates gradients over ``args.iter_size`` loader batches
    before one optimizer step, which is why every loss term is divided by
    ``iter_size * batch_size``. The loader is restarted when exhausted.

    Every ``args.step_size`` iterations the learning rate is adjusted. Every
    ``args.snapshots`` iterations the weights are saved to
    ``<args.param_dir>/bdcn_<step>.pth``. Every ``args.display`` iterations the
    mean loss over a sliding window of ``args.average_loss`` iterations is logged.

    Args:
        model: network whose forward returns a sequence of logit maps; all but
            the last are weighted by ``args.side_weight``, the last by
            ``args.fuse_weight``.
        args: namespace with dataset, optimisation, logging and checkpoint
            settings. ``args.resume`` must point to a dict checkpoint with
            'step', 'param' and 'solver' keys.
    """
    data_root = cfg.config[args.dataset]['data_root']
    data_lst = cfg.config[args.dataset]['data_lst']
    if 'Multicue' in args.dataset:
        data_lst = data_lst % args.k
    mean_bgr = np.array(cfg.config[args.dataset]['mean_bgr'])
    yita = args.yita if args.yita else cfg.config[args.dataset]['yita']
    crop_size = args.crop_size
    train_img = Data(data_root,
                     data_lst,
                     yita,
                     mean_bgr=mean_bgr,
                     crop_size=crop_size)
    trainloader = torch.utils.data.DataLoader(train_img,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=5)
    logger = args.logger
    params = _bdcn_param_groups(model, args.base_lr, args.weight_decay)
    optimizer = torch.optim.SGD(params,
                                momentum=args.momentum,
                                lr=args.base_lr,
                                weight_decay=args.weight_decay)
    start_step = 1
    mean_loss = []
    cur = 0
    pos = 0
    data_iter = iter(trainloader)
    iter_per_epoch = len(trainloader)
    logger.info('*' * 40)
    logger.info('train images in all are %d ' % iter_per_epoch)
    logger.info('*' * 40)

    start_time = time.time()
    if args.cuda:
        model.cuda()
    if args.resume:
        logger.info('resume from %s' % args.resume)
        state = torch.load(args.resume)
        start_step = state['step']
        optimizer.load_state_dict(state['solver'])
        # Restore the weights too; without this, training continued from the
        # checkpoint's step and optimizer state but with the initial weights.
        model.load_state_dict(state['param'])
    model.train()
    batch_size = args.iter_size * args.batch_size
    for step in range(start_step, args.max_iter + 1):
        optimizer.zero_grad()
        batch_loss = 0
        for _ in range(args.iter_size):
            if cur == iter_per_epoch:
                cur = 0
                data_iter = iter(trainloader)
            images, labels = next(data_iter)
            if args.cuda:
                images, labels = images.cuda(), labels.cuda()
            out = model(images)
            loss = 0
            for k in range(len(out) - 1):
                loss += args.side_weight * cross_entropy_loss2d(
                    out[k], labels, args.cuda, args.balance) / batch_size
            loss += args.fuse_weight * cross_entropy_loss2d(
                out[-1], labels, args.cuda, args.balance) / batch_size
            loss.backward()
            # .item() replaces loss.data[0], which raises IndexError on the
            # 0-dim tensors returned by modern PyTorch losses.
            batch_loss += loss.item()
            cur += 1
        # update parameter
        optimizer.step()
        # Sliding window of the last args.average_loss iteration losses.
        if len(mean_loss) < args.average_loss:
            mean_loss.append(batch_loss)
        else:
            mean_loss[pos] = batch_loss
            pos = (pos + 1) % args.average_loss
        if step % args.step_size == 0:
            adjust_learning_rate(optimizer, step, args.step_size)
        if step % args.snapshots == 0:
            # NOTE: only the bare state_dict is saved here; args.resume expects
            # a dict with 'step', 'param' and 'solver' keys instead.
            torch.save(model.state_dict(),
                       '%s/bdcn_%d.pth' % (args.param_dir, step))
        if step % args.display == 0:
            tm = time.time() - start_time
            logger.info(
                'iter: %d, lr: %e, loss: %f, time using: %f(%fs/iter)' %
                (step, optimizer.param_groups[0]['lr'], np.mean(mean_loss), tm,
                 tm / args.display))
            start_time = time.time()
예제 #9
0
def _mean_or_nan(values):
    """Return the mean of a list of floats, or NaN when the list is empty."""
    return float(np.mean(values)) if values else float('nan')


def train(model, args):
    """Train ``model`` with Adam for ``args.epochs`` epochs, validating after each.

    Gradients are accumulated over ``args.batch_size`` consecutive loader
    batches before each optimizer step; each batch loss is pre-divided by
    ``10 * args.batch_size`` for that reason. The summed loss of every step
    feeds a window of at most ``args.average_loss`` values. At the end of each
    epoch, the window's mean and the mean validation loss are written to
    TensorBoard and ``args.logger``.

    Every ``args.step_size`` epochs the learning rate is adjusted. Every
    ``args.snapshots`` epochs two files are written to ``args.param_dir``: the
    bare ``state_dict`` (``bdcn_<epoch>.pth``) and a checkpoint dict
    (``bdcn_<epoch>.pth.tar``, keys 'step', 'param', 'solver') that
    ``args.resume`` can load.

    Args:
        model: network whose forward returns a sequence of logit maps;
            ``out[0]`` to ``out[9]`` are weighted by ``args.side_weight`` and
            ``out[-1]`` by ``args.fuse_weight``.
        args: namespace with dataset, optimisation, logging and checkpoint
            settings.
    """
    logger = args.logger

    # Training dataloader
    data_root = cfg.config[args.dataset]['data_root']
    data_lst = cfg.config[args.dataset]['data_lst']
    mean_bgr = np.array(cfg.config[args.dataset]['mean_bgr'])
    train_img = Data(data_root, data_lst, mean_bgr=mean_bgr)
    trainloader = torch.utils.data.DataLoader(train_img,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=20)
    n_train = len(trainloader)

    # adam optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.base_lr,
                                 weight_decay=args.weight_decay)

    # Validation dataloader
    val_root = cfg.config_val[args.dataset]['data_root']
    val_lst = cfg.config_val[args.dataset]['data_lst']
    val_mean_bgr = np.array(cfg.config_val[args.dataset]['mean_bgr'])
    val_img = Data(val_root, val_lst, mean_bgr=val_mean_bgr)
    valloader = torch.utils.data.DataLoader(val_img,
                                            batch_size=args.batch_size,
                                            shuffle=False,
                                            num_workers=20)

    start_time = time.time()
    if args.cuda:
        model.cuda()
    if args.resume:
        state = torch.load(args.resume)
        optimizer.load_state_dict(state['solver'])
        model.load_state_dict(state['param'])

    # Number of loader batches whose gradients are accumulated per step.
    batch_size = args.batch_size

    # Writer will output to ./runs/ directory by default
    writer = SummaryWriter()
    epochs = args.epochs
    # Accumulation state is kept across epochs, like batch_index always was,
    # so a partial accumulation is completed in the next epoch.
    batch_index = 0
    batch_loss = 0.0
    optimizer.zero_grad()
    try:
        for epoch in range(epochs):
            # Train mode
            model.train()
            mean_loss = []
            val_mean_loss = []
            pos = 0
            with tqdm(total=n_train,
                      desc=f'Epoch {epoch + 1}/{epochs}',
                      unit='batch') as pbar:
                for images, labels in trainloader:
                    if args.cuda:
                        images, labels = images.cuda(), labels.cuda()
                    out = model(images)
                    loss = 0
                    # Loss function for 10 different intermediate outputs
                    for k in range(10):
                        loss += args.side_weight * cross_entropy_loss2d(
                            out[k], labels, args.cuda,
                            args.balance) / (10 * batch_size)

                    # loss function for fuse output
                    loss += args.fuse_weight * cross_entropy_loss2d(
                        out[-1], labels, args.cuda,
                        args.balance) / (10 * batch_size)

                    # Gradients add up across backward() calls until the step.
                    loss.backward()
                    batch_index += 1
                    # Accumulate a float; summing tensors would keep every
                    # batch's autograd graph alive.
                    batch_loss += loss.item()

                    if batch_index == batch_size:
                        optimizer.step()
                        # Clear gradients only after the step, so all
                        # batch_size backward passes contribute to it.
                        optimizer.zero_grad()
                        batch_index = 0

                        # Sliding window of recent step losses
                        if len(mean_loss) < args.average_loss:
                            mean_loss.append(batch_loss)
                        else:
                            mean_loss[pos] = batch_loss
                            pos = (pos + 1) % args.average_loss
                        batch_loss = 0.0

                    # total is len(trainloader), i.e. counted in batches.
                    pbar.update(1)
                    pbar.set_postfix(**{'loss (batch)': loss.item()})

            # Adjust learning rate
            if (epoch + 1) % args.step_size == 0:
                adjust_learning_rate(optimizer, epoch + 1, args.step_size,
                                     args.gamma)

            # Save BDCN weights
            if (epoch + 1) % args.snapshots == 0:
                torch.save(model.state_dict(),
                           '%s/bdcn_%d.pth' % (args.param_dir, epoch + 1))
                state = {
                    'step': epoch + 1,
                    'param': model.state_dict(),
                    'solver': optimizer.state_dict()
                }
                torch.save(state,
                           '%s/bdcn_%d.pth.tar' % (args.param_dir, epoch + 1))

            tm = time.time() - start_time

            # Evaluate mode
            model.eval()
            with torch.no_grad():
                for val_images, val_labels in valloader:
                    if args.cuda:
                        val_images, val_labels = val_images.cuda(), val_labels.cuda()
                    out = model(val_images)
                    val_loss = 0
                    for k in range(10):
                        val_loss += args.side_weight * cross_entropy_loss2d(
                            out[k], val_labels, args.cuda, args.balance) / 10
                    val_loss += args.fuse_weight * cross_entropy_loss2d(
                        out[-1], val_labels, args.cuda, args.balance) / 10
                    val_mean_loss.append(float(val_loss))

            # Empty windows (e.g. no optimizer step this epoch) report NaN
            # instead of crashing.
            train_loss = _mean_or_nan(mean_loss)
            val_loss_mean = _mean_or_nan(val_mean_loss)

            # Add scalar to tensorboard Loss/Train
            writer.add_scalars('Loss/train/val', {
                'Train loss': train_loss,
                'Validation loss': val_loss_mean
            }, epoch)

            logger.info(
                'lr: %e, loss: %f, validation loss: %f, time using: %f' %
                (optimizer.param_groups[0]['lr'], train_loss, val_loss_mean,
                 tm))

            start_time = time.time()
    finally:
        # Flush and release the TensorBoard event file.
        writer.close()
예제 #10
0
파일: train.py 프로젝트: huberthomas/bdcn
def train(model, args):
    """Train BDCN, optionally reporting a validation loss.

    Dataset paths and preprocessing settings come from
    ``cfg.config[args.dataset]``.  When that config has ``validation == 1``
    the validation list is taken from its ``val_lst`` entry if that file
    exists under ``data_root``; otherwise ``createValidationLst`` is asked
    to split ``data_lst`` into a training and a validation list.

    The model is optimized with SGD for steps ``start_step .. args.max_iter``
    (inclusive); each step accumulates gradients over ``args.iter_size``
    mini-batches before ``optimizer.step()``.  Every ``args.snapshots`` steps
    the weights (``bdcn_<step>.pth``) and a resumable checkpoint
    (``bdcn_<step>.pth.tar``) are written to ``args.param_dir``.  When
    validation is enabled, every ``args.val_step_size`` steps the validation
    loss summed over the whole validation loader is computed and logged.

    Args:
        model: network whose forward output is indexed as ``out[0]`` ..
            ``out[9]`` (weighted by ``args.side_weight``) and ``out[-1]``
            (weighted by ``args.fuse_weight``).
        args: parsed options; also provides ``args.logger``.
    """
    logger = args.logger
    dataset_cfg = cfg.config[args.dataset]
    data_root = dataset_cfg['data_root']
    data_lst = dataset_cfg['data_lst']

    val_lst = None
    if 'validation' in dataset_cfg and dataset_cfg['validation'] == 1:
        if 'val_lst' in dataset_cfg and os.path.exists(
                os.path.join(data_root, dataset_cfg['val_lst'])):
            val_lst = dataset_cfg['val_lst']
            # Log after assigning; previously val_lst was still None here.
            logger.info('Loading validation set from %s' % (val_lst))
        else:
            logger.info(
                'Automatically generating training and validation set.')
            data_lst, val_lst = createValidationLst(data_root, data_lst)
            logger.info('Finished and stored in %s and %s.' %
                        (data_lst, val_lst))

    if 'Multicue' in args.dataset:
        data_lst = data_lst % args.k

    mean_bgr = np.array(dataset_cfg['mean_bgr'])
    yita = args.yita if args.yita else dataset_cfg['yita']
    crop_size = args.crop_size
    crop_padding = args.crop_padding

    valloader = None
    if val_lst is not None:
        logger.info('Validation: enabled')
        val_img = Data(data_root,
                       val_lst,
                       yita,
                       mean_bgr=mean_bgr,
                       crop_size=crop_size,
                       shuffle=True,
                       crop_padding=crop_padding)
        valloader = torch.utils.data.DataLoader(val_img,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=8,
                                                drop_last=True)
    else:
        logger.info('Validation: disabled')

    train_img = Data(data_root,
                     data_lst,
                     yita,
                     mean_bgr=mean_bgr,
                     crop_size=crop_size,
                     shuffle=True,
                     crop_padding=crop_padding,
                     flip=True,
                     brightness=True,
                     blur=True,
                     rotate=True)
    trainloader = torch.utils.data.DataLoader(train_img,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=8,
                                              drop_last=True)

    base_lr = args.base_lr
    weight_decay = args.weight_decay
    # Per-layer multipliers, tried in order with re.match (anchored at the
    # start of the parameter name); the first matching rule wins.  Each rule
    # is (pattern, weight lr mult, weight decay mult, bias lr mult); biases
    # never get weight decay.  The empty pattern is the catch-all fallback.
    lr_rules = (
        (r'conv[1-5]_[1-3]_down', 0.1, 1, 0.2),
        (r'.*conv[1-4]_[1-3]', 1, 1, 2),
        (r'.*conv5_[1-3]', 100, 1, 200),
        (r'score_dsn[1-5]', 0.01, 1, 0.02),
        (r'upsample_[248](_5)?', 0, 0, 0),
        (r'.*msblock[1-5]_[1-3]\.conv', 1, 1, 2),
        (r'', 0.001, 1, 0.002),
    )
    params = []
    for key, v in dict(model.named_parameters()).items():
        for pattern, weight_lr, weight_wd, bias_lr in lr_rules:
            if re.match(pattern, key):
                break
        if 'weight' in key:
            lr_mult, wd_mult = weight_lr, weight_wd
        elif 'bias' in key:
            lr_mult, wd_mult = bias_lr, 0
        else:
            # Parameters that are neither weights nor biases are not
            # handed to the optimizer.
            continue
        params.append({
            'params': v,
            'lr': base_lr * lr_mult,
            'weight_decay': weight_decay * wd_mult,
            'name': key
        })

    optimizer = torch.optim.SGD(params,
                                momentum=args.momentum,
                                lr=args.base_lr,
                                weight_decay=args.weight_decay)
    start_step = 1
    mean_loss = []
    mean_loss_lst = []
    val_mean_loss = []
    cur = 0
    pos = 0
    val_pos = 0
    data_iter = iter(trainloader)
    iter_per_epoch = len(trainloader)
    logger.info('*' * 40)
    logger.info('train images in all are %d ' % iter_per_epoch)
    logger.info('*' * 40)
    for param_group in optimizer.param_groups:
        if logger:
            logger.info('%s: %s' % (param_group['name'], param_group['lr']))
    start_time = time.time()
    if args.cuda:
        model.cuda()
    if args.resume:
        logger.info('resume from %s' % args.resume)
        state = torch.load(args.resume)
        logger.info('*' * 40)
        logger.info('*' * 40)
        start_step = state['step']
        optimizer.load_state_dict(state['solver'])
        model.load_state_dict(state['param'])
    model.train()
    # Every loss term is scaled by 1 / (iter_size * batch_size).
    batch_size = args.iter_size * args.batch_size
    for step in range(start_step, args.max_iter + 1):
        optimizer.zero_grad()
        batch_loss = 0
        for _ in range(args.iter_size):
            # Restart the loader once an epoch's worth of batches is consumed.
            if cur == iter_per_epoch:
                cur = 0
                data_iter = iter(trainloader)

            images, labels = next(data_iter)

            if args.cuda:
                images, labels = images.cuda(), labels.cuda()

            images, labels = Variable(images), Variable(labels)

            out = model(images)

            loss = 0
            for k in range(10):
                loss += args.side_weight * cross_entropy_loss2d(
                    out[k], labels, args.cuda, args.balance) / batch_size

            loss += args.fuse_weight * cross_entropy_loss2d(
                out[-1], labels, args.cuda, args.balance) / batch_size
            loss.backward()
            batch_loss += loss.item()
            cur += 1
        # update parameter
        optimizer.step()
        # Ring buffer holding the last args.average_loss batch losses.
        if len(mean_loss) < args.average_loss:
            mean_loss.append(batch_loss)
        else:
            mean_loss[pos] = batch_loss
            pos = (pos + 1) % args.average_loss
        if step % args.step_size == 0:
            adjust_learning_rate(optimizer, step, args.step_size, args.gamma)
        if step % args.snapshots == 0:
            torch.save(model.state_dict(),
                       '%s/bdcn_%d.pth' % (args.param_dir, step))
            state = {
                'step': step + 1,
                'param': model.state_dict(),
                'solver': optimizer.state_dict()
            }
            torch.save(state, '%s/bdcn_%d.pth.tar' % (args.param_dir, step))
        if step % args.display == 0:
            tm = time.time() - start_time
            logger.info(
                'iter: %d, lr: %e, loss: %f, time using: %f(%fs/iter)' %
                (step, optimizer.param_groups[0]['lr'], np.mean(mean_loss), tm,
                 tm / args.display))
            start_time = time.time()

        # VALIDATION
        if valloader is not None and step % args.val_step_size == 0:
            model.eval()
            logger.info('mode: validation')
            val_batch_loss = 0
            # Measuring the validation loss needs no gradients.  The previous
            # version called backward() here, building autograd graphs and
            # filling .grad for nothing.
            with torch.no_grad():
                for val_images, val_labels in valloader:
                    if args.cuda:
                        val_images = val_images.cuda()
                        val_labels = val_labels.cuda()

                    val_images = Variable(val_images)
                    val_labels = Variable(val_labels)

                    val_out = model(val_images)

                    val_loss = 0
                    for k in range(10):
                        val_loss += args.side_weight * cross_entropy_loss2d(
                            val_out[k], val_labels, args.cuda,
                            args.balance) / batch_size

                    val_loss += args.fuse_weight * cross_entropy_loss2d(
                        val_out[-1], val_labels, args.cuda,
                        args.balance) / batch_size
                    val_batch_loss += val_loss.item()

            # Ring buffer holding the last args.average_loss validation losses.
            if len(val_mean_loss) < args.average_loss:
                val_mean_loss.append(val_batch_loss)
            else:
                val_mean_loss[val_pos] = val_batch_loss
                val_pos = (val_pos + 1) % args.average_loss

            mean_loss_lst.append(
                [step, np.mean(mean_loss),
                 np.mean(val_mean_loss)])

            for entry in mean_loss_lst:
                logger.info('iter: %d, loss: %f, val_loss: %f' %
                            (entry[0], entry[1], entry[2]))

            logger.info('mode: training')
            model.train()
예제 #11
0
파일: main.py 프로젝트: YacobBY/bdcn
def forwardAll(model, args):
    """Run the model over the configured test set and save one edge map per image.

    The input directory is ``cfg.config_test[args.dataset]['data_root']``
    unless ``args.inputDir`` is given.  For every image the last model output
    (``out[-1]``) is passed through a sigmoid, scaled by 255 and written with
    ``cv2.imwrite`` to ``<test_root>/<args.res_dir>/<image file name>``.
    Per-image timings in milliseconds are written to ``timeRecords.txt`` in
    that same directory; the accumulated and overall times are printed.

    Args:
        model: network to evaluate; moved to CUDA when ``args.cuda`` is set.
        args: parsed options (``dataset``, ``inputDir``, ``res_dir``, ``cuda``).
    """
    test_root = cfg.config_test[args.dataset]['data_root']
    if args.inputDir is not None:
        test_root = args.inputDir

    logging.info('Processing: %s' % test_root)
    test_lst = cfg.config_test[args.dataset]['data_lst']

    # Output names, indexed in the same order as the (unshuffled) loader.
    imageFileNames = createDataList(test_root, test_lst)

    mean_bgr = np.array(cfg.config_test[args.dataset]['mean_bgr'])
    test_img = Data(test_root,
                    test_lst,
                    mean_bgr=mean_bgr,
                    shuffle=False,
                    crop_padding=0,
                    crop_size=None)
    testloader = torch.utils.data.DataLoader(test_img,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=1)
    save_dir = join(test_root, args.res_dir)
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)

    if args.cuda:
        model.cuda()

    model.eval()
    start_time = time.time()
    all_t = 0
    # ``with`` guarantees the timing file is closed even if inference fails.
    with open(join(save_dir, 'timeRecords.txt'), "w") as timeRecords:
        timeRecords.write('# filename time[ms]\n')

        for i, (data, _) in enumerate(testloader):
            if args.cuda:
                data = data.cuda()

            # Inference runs on both CPU and GPU.  Previously this body was
            # nested under ``if args.cuda`` so CPU runs produced no output.
            with torch.no_grad():
                tm = time.time()

                out = model(data)
                fuse = torch.sigmoid(out[-1]).cpu().data.numpy()[0, 0, :, :]

                elapsedTime = time.time() - tm
                timeRecords.write('%s %f\n' %
                                  (imageFileNames[i], elapsedTime * 1000))

                cv2.imwrite(os.path.join(save_dir, '%s' % imageFileNames[i]),
                            fuse * 255)

                all_t += time.time() - tm

    print(all_t)
    print('Overall Time use: ', time.time() - start_time)
예제 #12
0
파일: train.py 프로젝트: xavysp/BDCN_xsp
def train(model, args, device='gpu'):
    """Train BDCN with per-output loss weights and periodic visualization.

    Dataset paths and preprocessing settings come from
    ``cfg.config[args.dataset]``.  The model is optimized with SGD for steps
    ``start_step .. args.max_iter`` (inclusive), accumulating gradients over
    ``args.iter_size`` mini-batches per step.  Snapshots (``bdcn_<step>.pth``
    and a resumable ``bdcn_<step>.pth.tar``) go to ``args.param_dir`` every
    ``args.snapshots`` steps.  Every 30 steps a visualization of the current
    batch, its labels and the sigmoid of every model output is written to
    ``results/training/curr_result.png``.

    Args:
        model: network whose outputs are paired with ``l_weights`` (11
            entries) by ``zip`` to build the loss.
        args: parsed options; also provides ``args.logger``.
        device: the model is moved to CUDA only when ``str(device) == 'gpu'``.
    """
    data_root = cfg.config[args.dataset]['data_root']
    data_lst = cfg.config[args.dataset]['data_lst']
    if 'Multicue' in args.dataset:
        data_lst = data_lst % args.k

    mean_bgr = np.array(cfg.config[args.dataset]['mean_bgr'])
    yita = args.yita if args.yita else cfg.config[args.dataset]['yita']
    crop_size = args.crop_size
    train_img = Data(data_root,
                     data_lst,
                     yita,
                     mean_bgr=mean_bgr,
                     crop_size=crop_size,
                     is_train=True,
                     dataset_name=args.dataset)
    trainloader = torch.utils.data.DataLoader(train_img,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=5)

    base_lr = args.base_lr
    weight_decay = args.weight_decay
    logger = args.logger
    # Per-layer multipliers, tried in order with re.match (anchored at the
    # start of the parameter name); the first matching rule wins.  Each rule
    # is (pattern, weight lr mult, weight decay mult, bias lr mult); biases
    # never get weight decay.  The empty pattern is the catch-all fallback.
    lr_rules = (
        (r'conv[1-5]_[1-3]_down', 0.1, 1, 0.2),
        (r'.*conv[1-4]_[1-3]', 1, 1, 2),
        (r'.*conv5_[1-3]', 100, 1, 200),
        (r'score_dsn[1-5]', 0.01, 1, 0.02),
        (r'upsample_[248](_5)?', 0, 0, 0),
        (r'.*msblock[1-5]_[1-3]\.conv', 1, 1, 2),
        (r'', 0.001, 1, 0.002),
    )
    params = []
    for key, v in dict(model.named_parameters()).items():
        for pattern, weight_lr, weight_wd, bias_lr in lr_rules:
            if re.match(pattern, key):
                break
        if 'weight' in key:
            lr_mult, wd_mult = weight_lr, weight_wd
        elif 'bias' in key:
            lr_mult, wd_mult = bias_lr, 0
        else:
            # Parameters that are neither weights nor biases are not
            # handed to the optimizer.
            continue
        params.append({
            'params': v,
            'lr': base_lr * lr_mult,
            'weight_decay': weight_decay * wd_mult,
            'name': key
        })
    optimizer = torch.optim.SGD(params,
                                momentum=args.momentum,
                                lr=args.base_lr,
                                weight_decay=args.weight_decay)
    criterion = cross_entropy_loss2d
    start_step = 0
    mean_loss = []
    cur = 0
    pos = 0
    data_iter = iter(trainloader)
    iter_per_epoch = len(trainloader)
    logger.info('*' * 40)
    logger.info('train images in all are %d ' % iter_per_epoch)
    logger.info('*' * 40)
    for param_group in optimizer.param_groups:
        if logger:
            logger.info('%s: %s' % (param_group['name'], param_group['lr']))
    start_time = time.time()
    # NOTE(review): the model's device follows ``device`` while the inputs
    # below follow ``args.cuda``; the two must agree or forward() will fail.
    if str(device) == 'gpu':
        model.cuda()
    if args.resume and args.use_prev_trained:
        logger.info('resume from %s' % args.resume)
        state = torch.load(args.resume)
        start_step = state['step']
        optimizer.load_state_dict(state['solver'])
        model.load_state_dict(state['param'])
        print("Starting form previous trained weights", start_step)
    model.train()
    print('Python %s on %s' % (sys.version, sys.platform))
    # visualize training
    visu_res = 'results/training'
    os.makedirs(visu_res, exist_ok=True)
    l_weights = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1.1]
    for step in range(start_step, args.max_iter + 1):
        optimizer.zero_grad()
        batch_loss = 0
        for _ in range(args.iter_size):
            # Restart the loader once an epoch's worth of batches is consumed.
            if cur == iter_per_epoch:
                cur = 0
                data_iter = iter(trainloader)
            images, labels = next(data_iter)
            if args.cuda:
                images, labels = images.cuda(), labels.cuda()
            images, labels = Variable(images), Variable(labels)
            out = model(images)
            # Weighted sum of per-output losses: out[j] uses l_weights[j].
            loss = sum([
                criterion(preds, labels, l_w, args.cuda)
                for preds, l_w in zip(out, l_weights)
            ])
            loss.backward()
            batch_loss += loss.cpu().detach().numpy()
            cur += 1
        # update parameter
        optimizer.step()
        # Ring buffer holding the last args.average_loss batch losses.
        if len(mean_loss) < args.average_loss:
            mean_loss.append(batch_loss)
        else:
            mean_loss[pos] = batch_loss
            pos = (pos + 1) % args.average_loss
        if step % args.step_size == 0:
            adjust_learning_rate(optimizer, step, args.step_size, args.gamma)
        if step % args.snapshots == 0:
            torch.save(model.state_dict(),
                       '%s/bdcn_%d.pth' % (args.param_dir, step))
            state = {
                'step': step + 1,
                'param': model.state_dict(),
                'solver': optimizer.state_dict()
            }
            torch.save(state, '%s/bdcn_%d.pth.tar' % (args.param_dir, step))
        if step % args.display == 0:
            tm = time.time() - start_time
            print('iter: %d, lr: %e, loss: %f, time using: %f(%fs/iter)' %
                  (step, optimizer.param_groups[0]['lr'], np.mean(mean_loss),
                   tm, tm / args.display))
            start_time = time.time()
        if step % 30 == 0:
            # Inputs, labels, then the sigmoid of every model output.
            res_data = [images.cpu().numpy(), labels.cpu().numpy()]
            res_data.extend(
                torch.sigmoid(pred).cpu().detach().numpy() for pred in out)
            vis_imgs = visualize_result(res_data, arg=args)
            cv.imwrite(os.path.join(visu_res, 'curr_result.png'), vis_imgs)
            del res_data
            print("*****Visualization Epoch:" + str(step + 1) + " Loss:" +
                  '%.5f' % np.mean(mean_loss) + " training")
예제 #13
0
def train(model, args):
    data_root = cfg.config[args.dataset]['data_root']
    data_lst = cfg.config[args.dataset]['data_lst']
    if 'Multicue' in args.dataset:
        data_lst = data_lst % args.k
    mean_bgr = np.array(cfg.config[args.dataset]['mean_bgr'])
    yita = args.yita if args.yita else cfg.config[args.dataset]['yita']
    crop_size = args.crop_size
    train_img = Data(data_root,
                     data_lst,
                     yita,
                     mean_bgr=mean_bgr,
                     crop_size=crop_size)
    trainloader = torch.utils.data.DataLoader(train_img,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=0)  # num_workers=5

    params_dict = dict(model.named_parameters())
    base_lr = args.base_lr
    weight_decay = args.weight_decay
    logger = args.logger
    params = []
    for key, v in params_dict.items():
        if re.match(r'conv[1-5]_[1-3]_down', key):
            if 'weight' in key:
                params += [{
                    'params': v,
                    'lr': base_lr * 0.1,
                    'weight_decay': weight_decay * 1,
                    'name': key
                }]
            elif 'bias' in key:
                params += [{
                    'params': v,
                    'lr': base_lr * 0.2,
                    'weight_decay': weight_decay * 0,
                    'name': key
                }]
        elif re.match(r'.*conv[1-4]_[1-3]', key):
            if 'weight' in key:
                params += [{
                    'params': v,
                    'lr': base_lr * 1,
                    'weight_decay': weight_decay * 1,
                    'name': key
                }]
            elif 'bias' in key:
                params += [{
                    'params': v,
                    'lr': base_lr * 2,
                    'weight_decay': weight_decay * 0,
                    'name': key
                }]
        elif re.match(r'.*conv5_[1-3]', key):
            if 'weight' in key:
                params += [{
                    'params': v,
                    'lr': base_lr * 100,
                    'weight_decay': weight_decay * 1,
                    'name': key
                }]
            elif 'bias' in key:
                params += [{
                    'params': v,
                    'lr': base_lr * 200,
                    'weight_decay': weight_decay * 0,
                    'name': key
                }]
        elif re.match(r'score_dsn[1-5]', key):
            if 'weight' in key:
                params += [{
                    'params': v,
                    'lr': base_lr * 0.01,
                    'weight_decay': weight_decay * 1,
                    'name': key
                }]
            elif 'bias' in key:
                params += [{
                    'params': v,
                    'lr': base_lr * 0.02,
                    'weight_decay': weight_decay * 0,
                    'name': key
                }]
        elif re.match(r'upsample_[248](_5)?', key):
            if 'weight' in key:
                params += [{
                    'params': v,
                    'lr': base_lr * 0,
                    'weight_decay': weight_decay * 0,
                    'name': key
                }]
            elif 'bias' in key:
                params += [{
                    'params': v,
                    'lr': base_lr * 0,
                    'weight_decay': weight_decay * 0,
                    'name': key
                }]
        elif re.match(r'.*msblock[1-5]_[1-3]\.conv', key):
            if 'weight' in key:
                params += [{
                    'params': v,
                    'lr': base_lr * 1,
                    'weight_decay': weight_decay * 1,
                    'name': key
                }]
            elif 'bias' in key:
                params += [{
                    'params': v,
                    'lr': base_lr * 2,
                    'weight_decay': weight_decay * 0,
                    'name': key
                }]
        else:
            if 'weight' in key:
                params += [{
                    'params': v,
                    'lr': base_lr * 0.001,
                    'weight_decay': weight_decay * 1,
                    'name': key
                }]
            elif 'bias' in key:
                params += [{
                    'params': v,
                    'lr': base_lr * 0.002,
                    'weight_decay': weight_decay * 0,
                    'name': key
                }]
    # optimizer = torch.optim.SGD(params, momentum=args.momentum,
    #                             lr=args.base_lr, weight_decay=args.weight_decay)
    nm = np.loadtxt("../data/BSR/BSDS500/data/train_pair.lst", dtype=str)
    start_step = 1
    mean_loss = []
    cur = 0
    pos = 0
    data_iter = iter(trainloader)
    iter_per_epoch = len(trainloader)
    logger.info('*' * 40)
    logger.info('train images in all are %d ' % iter_per_epoch)
    logger.info('*' * 40)
    # for param_group in optimizer.param_groups:
    #     if logger:
    #         logger.info('%s: %s' % (param_group['name'], param_group['lr']))
    start_time = time.time()
    if args.cuda:
        model.cuda()
    if args.resume:
        logger.info('resume from %s' % args.resume)
        state = torch.load(args.resume)
        start_step = 40000
        # optimizer.load_state_dict(state['solver'])
        model.load_state_dict(state)
    model.eval()
    batch_size = args.iter_size * args.batch_size
    for step in range(start_step, args.max_iter + 1):
        # optimizer.zero_grad()
        batch_loss = 0
        # debug
        print("step:\t", step)
        for i in range(args.iter_size):
            if cur == iter_per_epoch:
                cur = 0
                data_iter = iter(trainloader)
            images, labels = next(data_iter)
            if args.cuda:
                images, labels = images.cuda(), labels.cuda()
            images, labels = Variable(images), Variable(labels)
            out = model(images)
            loss = 0

            side_weight = np.arange(args.side_weight, 1,
                                    (1 - args.side_weight) / 10)
            for k in range(10):
                loss += side_weight[i] * cross_entropy_loss2d(
                    out[k], labels, args.cuda, args.balance) / batch_size
            loss += args.fuse_weight * cross_entropy_loss2d(out[-1], labels, args.cuda, args.balance) / batch_size \
                    + args.reDice_weight * re_Dice_Loss(out[-1], labels, args.cuda, args.balance) / batch_size

            loss.backward()
            print("loss: ", loss.item())
            pic = Image.fromarray(images.cpu().data.numpy()[0, 0, :, :] * 255)
            pic = pic.convert('L')
            pic.save(
                os.path.join("./results", 'fuse/images/train',
                             '%s.png' % loss.item()), "PNG")
            batch_loss += loss.item(
            )  # loss.data[0]  #not suitable for pytorch0.5
            cur += 1
        # update parameter
        """"---------------------------------commented for curriculum learning--------------------------------------"""