Example #1
def test(val_loader, net, criterion):
    net.eval()

    val_loss = 0
    cm_py = torch.zeros(
        (val_loader.dataset.num_classes,
         val_loader.dataset.num_classes)).type(torch.IntTensor).cuda()
    for vi, data in enumerate(val_loader):
        inputs, gts_, _ = data
        with torch.no_grad():
            inputs = Variable(inputs).cuda()
            gts = Variable(gts_).cuda()

        outputs, _ = net(inputs)
        predictions_py = outputs.data.max(1)[1].squeeze_(1)
        loss = criterion(outputs, gts)
        vl_loss = loss.item()
        val_loss += vl_loss

        cm_py = confusion_matrix_pytorch(cm_py, predictions_py.view(-1),
                                         gts_.cuda().view(-1),
                                         val_loader.dataset.num_classes)

        len_val = len(val_loader)
        progress_bar(vi, len_val, '[val loss %.5f]' % (val_loss / (vi + 1)))

        del outputs
        del vl_loss
    acc, mean_iu, iu = evaluate(cm_py.cpu().numpy())
    print(' ')
    print(' [val acc %.5f], [val iu %.5f]' % (acc, mean_iu))

    return val_loss / len(val_loader), acc, mean_iu, iu
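Every example on this page reports progress through a progress_bar helper defined elsewhere in its repository. Examples #1, #3, #4 and #5 call it as progress_bar(current_step, total_steps, message), while Example #2 uses a three-counter variant. The snippet below is only an illustrative sketch of the first signature; the real helpers in these projects are richer (they usually add timing and nicer formatting).

import sys

def progress_bar(current, total, msg=''):
    # Illustrative sketch: draw a single-line bar for step `current` out of
    # `total` (0-based), append `msg`, and overwrite the line on each call.
    bar_len = 30
    filled = int(bar_len * (current + 1) / total)
    bar = '=' * filled + '.' * (bar_len - filled)
    sys.stdout.write('\r[%s] %d/%d %s' % (bar, current + 1, total, msg))
    if current + 1 == total:
        sys.stdout.write('\n')
    sys.stdout.flush()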
Example #2
	def run_threads(threads, sending, completed, total):
		# Run threads
		for thread in threads:
			sending += 1 # Sending
			progress_bar(sending, completed, total)
			thread.start()

		# Wait for threads to complete
		for thread in threads:
			completed += 1
			progress_bar(sending, completed, total)
			thread.join()

		return sending, completed
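run_threads expects a list of constructed but not yet started threading.Thread objects and a progress_bar(sending, completed, total) helper that takes three counters. A hypothetical call, assuming run_threads is reachable as a plain function, could look like this (the worker function and URL list are illustrative, not part of the original code):

import threading

def send_request(url):
    pass  # hypothetical worker, e.g. perform one HTTP request

urls = ['https://example.com/a', 'https://example.com/b']
threads = [threading.Thread(target=send_request, args=(u,)) for u in urls]
sending, completed = run_threads(threads, sending=0, completed=0,
                                 total=len(threads))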
Example #3
def train(train_loader, net, criterion, optimizer, supervised=False):
    net.train()
    train_loss = 0
    cm_py = torch.zeros(
        (train_loader.dataset.num_classes,
         train_loader.dataset.num_classes)).type(torch.IntTensor).cuda()
    for i, data in enumerate(train_loader):
        optimizer.zero_grad()
        if supervised:
            im_s, t_s_, _ = data
        else:
            im_s, t_s_, _, _, _ = data

        t_s, im_s = Variable(t_s_).cuda(), Variable(im_s).cuda()
        # Get output of network
        outputs, _ = net(im_s)
        # Get segmentation maps
        predictions_py = outputs.data.max(1)[1].squeeze_(1)
        loss = criterion(outputs, t_s)
        train_loss += loss.item()

        loss.backward()
        nn.utils.clip_grad_norm_(net.parameters(), max_norm=4)
        optimizer.step()

        cm_py = confusion_matrix_pytorch(cm_py, predictions_py.view(-1),
                                         t_s_.cuda().view(-1),
                                         train_loader.dataset.num_classes)

        progress_bar(i, len(train_loader),
                     '[train loss %.5f]' % (train_loss / (i + 1)))

        del outputs
        del loss
        gc.collect()
    print(' ')
    acc, mean_iu, iu = evaluate(cm_py.cpu().numpy())
    print(' [train acc %.5f], [train iu %.5f]' % (acc, mean_iu))
    return train_loss / (len(train_loader)), 0, acc, mean_iu
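Examples #1, #3 and #5 accumulate a per-batch confusion matrix with confusion_matrix_pytorch and then derive pixel accuracy, mean IoU and per-class IoU with evaluate. Both helpers are defined elsewhere in the repository; the sketch below is one plausible implementation that matches how they are called here (the real code may handle ignore labels or use a faster vectorized accumulation).

import numpy as np

def confusion_matrix_pytorch(cm, output_flat, target_flat, num_classes):
    # Accumulate counts into a (num_classes x num_classes) tensor:
    # rows index the ground-truth class, columns the predicted class.
    for i in range(num_classes):
        for j in range(num_classes):
            cm[i, j] += ((target_flat == i) & (output_flat == j)).sum().int()
    return cm

def evaluate(cm):
    # cm is a numpy confusion matrix; return overall pixel accuracy,
    # mean IoU and the per-class IoU vector.
    tp = np.diag(cm).astype(np.float64)
    acc = tp.sum() / max(cm.sum(), 1)
    denom = cm.sum(axis=0) + cm.sum(axis=1) - tp
    iu = tp / np.maximum(denom, 1)
    return acc, iu.mean(), iu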
Example #4
def optimize_q_network(
    args,
    memory,
    Transition,
    policy_net,
    target_net,
    optimizerP,
    BATCH_SIZE=32,
    GAMMA=0.999,
    dqn_epochs=1,
):
    """This function optimizes the policy network

    :(ReplayMemory) memory: Experience replay buffer
    :param Transition: definition of the experience replay tuple
    :param policy_net: Policy network
    :param target_net: Target network
    :param optimizerP: Optimizer of the policy network
    :param BATCH_SIZE: (int) Batch size to sample from the experience replay
    :param GAMMA: (float) Discount factor
    :param dqn_epochs: (int) Number of epochs to train the DQN
    """
    # Code adapted from https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html
    if len(memory) < BATCH_SIZE:
        return
    print("Optimize model...")
    print(len(memory))
    policy_net.train()
    loss_item = 0
    for ep in range(dqn_epochs):
        optimizerP.zero_grad()
        transitions = memory.sample(BATCH_SIZE)
        # Transpose the batch (see http://stackoverflow.com/a/19343/3343043 for
        # detailed explanation).
        batch = Transition(*zip(*transitions))
        # Compute a mask of non-final states and concatenate the batch elements
        non_final_mask = torch.tensor(tuple(
            map(lambda s: s is not None, batch.next_state)),
                                      dtype=torch.uint8).cuda()

        non_final_next_states = torch.cat(
            [s for s in batch.next_state if s is not None])
        non_final_next_states_subset = torch.cat(
            [s for s in batch.next_state_subset if s is not None])

        state_batch = torch.cat(batch.state)
        state_batch_subset = torch.cat(batch.state_subset)
        action_batch = torch.Tensor([batch.action
                                     ]).view(-1).type(torch.LongTensor)
        reward_batch = torch.Tensor([batch.reward]).view(-1).cuda()
        # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
        # columns of actions taken
        q_val = policy_net(state_batch.cuda(), state_batch_subset.cuda())
        state_batch.cpu()
        state_batch_subset.cpu()
        state_action_values = q_val.gather(1, action_batch.unsqueeze(1).cuda())
        action_batch.cpu()
        # Compute V(s_{t+1}) for all next states.
        next_state_values = torch.zeros(BATCH_SIZE).cuda()
        # Double dqn, so we compute the values with the target network, we choose the actions with the policy network
        if non_final_mask.sum().item() > 0:
            v_val_act = policy_net(
                non_final_next_states.cuda(),
                non_final_next_states_subset.cuda()).detach()
            v_val = target_net(non_final_next_states.cuda(),
                               non_final_next_states_subset.cuda()).detach()
            non_final_next_states.cpu()
            non_final_next_states_subset.cpu()
            act = v_val_act.max(1)[1]
            next_state_values[non_final_mask] = v_val.gather(
                1, act.unsqueeze(1)).view(-1)
        # Compute the expected Q values
        expected_state_action_values = (next_state_values *
                                        GAMMA) + reward_batch
        # Compute Huber loss
        loss = F.smooth_l1_loss(state_action_values.view(-1),
                                expected_state_action_values)
        loss_item += loss.item()
        progress_bar(ep, dqn_epochs,
                     "[DQN loss %.5f]" % (loss_item / (ep + 1)))
        loss.backward()
        optimizerP.step()

        del q_val
        del expected_state_action_values
        del loss
        del next_state_values
        del reward_batch
        if non_final_mask.sum().item() > 0:
            del act
            del v_val
            del v_val_act
        del state_action_values
        del state_batch
        del action_batch
        del non_final_mask
        del non_final_next_states
        del batch
        del transitions
    lab_set = open(os.path.join(args.ckpt_path, args.exp_name, "q_loss.txt"),
                   "a")
    lab_set.write("%f" % (loss_item))
    lab_set.write("\n")
    lab_set.close()
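optimize_q_network follows the double-DQN recipe from the linked PyTorch tutorial, except that every state is split into a (state, state_subset) pair fed to the networks together. It assumes a Transition namedtuple and a ReplayMemory buffer roughly like the ones below; the field names are inferred from how batch is accessed above, so the real definitions may differ.

import random
from collections import namedtuple

Transition = namedtuple('Transition',
                        ('state', 'state_subset', 'action',
                         'next_state', 'next_state_subset', 'reward'))

class ReplayMemory:
    # Fixed-size cyclic buffer of Transition tuples.
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)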
Example #5
def validate(val_loader,
             net,
             criterion,
             optimizer,
             epoch,
             best_record,
             args,
             final_final_test=False):
    net.eval()

    val_loss = 0
    cm_py = torch.zeros(
        (val_loader.dataset.num_classes,
         val_loader.dataset.num_classes)).type(torch.IntTensor).cuda()
    for vi, data in enumerate(val_loader):
        inputs, gts_, _ = data
        with torch.no_grad():
            inputs = Variable(inputs).cuda()
            gts = Variable(gts_).cuda()
        outputs, _ = net(inputs)
        # Make sure both output and target have the same dimensions
        if outputs.shape[2:] != gts.shape[1:]:
            outputs = outputs[:, :, 0:min(outputs.shape[2], gts.shape[1]),
                              0:min(outputs.shape[3], gts.shape[2])]
            gts = gts[:, 0:min(outputs.shape[2], gts.shape[1]),
                      0:min(outputs.shape[3], gts.shape[2])]
        predictions_py = outputs.data.max(1)[1].squeeze_(1)
        loss = criterion(outputs, gts)
        vl_loss = loss.item()
        val_loss += vl_loss

        cm_py = confusion_matrix_pytorch(cm_py, predictions_py.view(-1),
                                         gts_.cuda().view(-1),
                                         val_loader.dataset.num_classes)

        len_val = len(val_loader)
        progress_bar(vi, len_val, '[val loss %.5f]' % (val_loss / (vi + 1)))

        del outputs
        del vl_loss
        del loss
        del predictions_py
    acc, mean_iu, iu = evaluate(cm_py.cpu().numpy())
    print(' ')
    print(' [val acc %.5f], [val iu %.5f]' % (acc, mean_iu))

    if not final_final_test:
        if mean_iu > best_record['mean_iu']:
            best_record['val_loss'] = val_loss / len(val_loader)
            best_record['epoch'] = epoch
            best_record['acc'] = acc
            best_record['iu'] = iu
            best_record['mean_iu'] = mean_iu

            torch.save(
                net.cpu().state_dict(),
                os.path.join(args.ckpt_path, args.exp_name,
                             'best_jaccard_val.pth'))
            net.cuda()
            torch.save(
                optimizer.state_dict(),
                os.path.join(args.ckpt_path, args.exp_name,
                             'opt_best_jaccard_val.pth'))

        ## Save checkpoint every epoch
        torch.save(
            net.cpu().state_dict(),
            os.path.join(args.ckpt_path, args.exp_name,
                         'last_jaccard_val.pth'))
        net.cuda()
        torch.save(
            optimizer.state_dict(),
            os.path.join(args.ckpt_path, args.exp_name,
                         'opt_last_jaccard_val.pth'))

        print('best record: [val loss %.5f], [acc %.5f], [mean_iu %.5f],'
              ' [epoch %d]' % (best_record['val_loss'], best_record['acc'],
                               best_record['mean_iu'], best_record['epoch']))

    print('----------------------------------------')

    return val_loss / len(val_loader), acc, mean_iu, iu, best_record
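Examples #3 and #5 are normally driven by an outer epoch loop. The sketch below shows one hypothetical way to wire them together (num_epochs and the pre-built loaders, net, criterion, optimizer and args are assumed to exist; the best_record keys match the ones validate reads and updates).

best_record = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'iu': 0, 'mean_iu': 0}
for epoch in range(num_epochs):
    tr_loss, _, tr_acc, tr_mean_iu = train(train_loader, net, criterion,
                                           optimizer, supervised=True)
    (val_loss, val_acc, val_mean_iu,
     val_iu, best_record) = validate(val_loader, net, criterion, optimizer,
                                     epoch, best_record, args)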