def train(trainloader, net, index, optimizer, epoch, use_cuda):
    losses = AverageMeter()

    print('\nIndex: %d \t Epoch: %d' % (index, epoch))

    net.train()

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs = inputs.cuda()
        optimizer.zero_grad()
        inputs_Var = Variable(inputs)  # Variable is a no-op wrapper in PyTorch >= 0.4
        outputs = net(inputs_Var, index)

        # record loss
        losses.update(outputs.item(), inputs.size(0))  # .item(), not the removed .data[0] idiom

        outputs.backward()
        optimizer.step()

    print('train_loss_{}'.format(index), losses.avg, epoch)
    # log to TensorBoard
    if args.tensorboard:
        log_value('train_loss_{}'.format(index), losses.avg, epoch)
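All of these excerpts rely on an AverageMeter helper that is never defined, along with module-level imports (torch, torch.autograd.Variable) and, in some examples, tensorboard_logger's log_value and an args namespace. A minimal sketch of the meter, assuming the conventional definition from the PyTorch ImageNet example (a val/avg pair plus an update(value, n) method):

class AverageMeter(object):
    """Tracks the most recent value and the running average of a metric."""

    def __init__(self):
        self.val = 0.0    # most recent value
        self.sum = 0.0    # weighted sum of all recorded values
        self.count = 0    # total weight (typically the sample count)
        self.avg = 0.0    # running average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count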
Example #2
def train(trainloader, net, index, optimizer, epoch, use_cuda, logger):
    losses = AverageMeter()

    print('\nIndex: %d \t Epoch: %d' % (index, epoch))

    net.train()

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs = inputs.cuda()
        optimizer.zero_grad()
        inputs_Var = Variable(inputs)
        outputs = net(inputs_Var, index)

        # record loss
        losses.update(outputs.item(), inputs.size(0))

        outputs.backward()
        
        '''
        # gradient clipping for the mlp arch (disabled)
        ch = net.named_parameters()
        for c in ch:
            if 'enc' in c[0]:
                torch.nn.utils.clip_grad_norm_(c[1], c[1].mean(dtype=float))
            d1 = c[1].view(-1)
            if torch.isnan(d1).any():
                print('NaN detected in parameter {}'.format(c[0]))
                exit(0)
        del ch
        '''
        '''
        # gradient clipping for the convolutional arch (disabled)
        ch = net.named_parameters()
        for c in ch:
            #if ('enc' in c[0] and 'benc' not in c[0]) or ('dec' in c[0] and 'bdec' not in c[0]):
                #torch.nn.utils.clip_grad_norm_(c[1], c[1].mean(dtype=float))
            d1 = c[1].view(-1)
            if torch.isnan(d1).any():
                print('NaN detected in parameter {}'.format(c[0]))
                exit(0)
        del ch
        '''
        
        optimizer.step()

    # log to TensorBoard
    if logger:
        logger.log_value('train_loss_{}'.format(index), losses.avg, epoch)
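The disabled blocks above clip each parameter separately, using the parameter's own mean as the norm bound. For comparison, a minimal sketch of the conventional PyTorch pattern, which clips the global gradient norm in one call between backward() and step(); the max_norm value here is an assumption, not taken from the original code:

import torch

def clipped_step(net, loss, optimizer, max_norm=1.0):
    # Backpropagate, clip the global gradient norm, then update the weights.
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=max_norm)
    optimizer.step()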
Example #3
def test(testloader, net, index, epoch, use_cuda):
    losses = AverageMeter()

    net.eval()

    with torch.no_grad():  # the volatile=True flag was removed in PyTorch 0.4
        for batch_idx, (inputs, targets) in enumerate(testloader):
            if use_cuda:
                inputs = inputs.cuda()
            inputs_Var = Variable(inputs)
            outputs = net(inputs_Var, index)

            # measure accuracy and record loss
            losses.update(outputs.item(), inputs.size(0))

    # log to TensorBoard
    if args.tensorboard:
        log_value('val_loss_{}'.format(index), losses.avg, epoch)
Example #4
def test(testloader, net, index, epoch, use_cuda, logger):
    losses = AverageMeter()

    net.eval()
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            if use_cuda:
                inputs = inputs.cuda()
            inputs_Var = Variable(inputs)
            outputs = net(inputs_Var, index)

            # measure accuracy and record loss
            losses.update(outputs.item(), inputs.size(0))

    # log to TensorBoard
    if logger:
        logger.log_value('val_loss_{}'.format(index), losses.avg, epoch)
Example #5
def train(trainloader, net, index, optimizer, epoch, use_cuda, logger):
    losses = AverageMeter()

    print('\nIndex: %d \t Epoch: %d' % (index, epoch))

    net.train()

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs = inputs.cuda()
        optimizer.zero_grad()
        inputs_Var = Variable(inputs)
        outputs = net(inputs_Var, index)

        # record loss
        losses.update(outputs.item(), inputs.size(0))

        outputs.backward()
        optimizer.step()

    # log to TensorBoard
    if logger:
        logger.log_value('train_loss_{}'.format(index), losses.avg, epoch)
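For context, a minimal sketch of how a train()/test() pair with the signatures above might be driven across epochs. Net, num_epochs, index, logger, and the data loaders are hypothetical stand-ins, not names from the original code:

import torch

use_cuda = torch.cuda.is_available()
net = Net()  # hypothetical model whose forward takes (inputs, index)
if use_cuda:
    net = net.cuda()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)  # assumed optimizer

for epoch in range(num_epochs):  # num_epochs assumed to be defined elsewhere
    train(trainloader, net, index, optimizer, epoch, use_cuda, logger)
    test(testloader, net, index, epoch, use_cuda, logger)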
Example #6
def train(trainloader, net, optimizer, criterion1, criterion2, epoch, use_cuda, _sigma1, _sigma2, _lambda, logger):
    losses = AverageMeter()
    losses1 = AverageMeter()
    losses2 = AverageMeter()

    print('\n Epoch: %d' % epoch)

    net.train()

    for batch_idx, (inputs, pairweights, sampweights, pairs, index) in enumerate(trainloader):
        inputs = torch.squeeze(inputs, 0)
        pairweights = torch.squeeze(pairweights)
        sampweights = torch.squeeze(sampweights)
        index = torch.squeeze(index)
        pairs = pairs.view(-1, 2)

        if use_cuda:
            inputs = inputs.cuda()
            pairweights = pairweights.cuda()
            sampweights = sampweights.cuda()
            index = index.cuda()
            pairs = pairs.cuda()

        optimizer.zero_grad()
        inputs_Var, sampweights, pairweights = Variable(inputs), Variable(sampweights, requires_grad=False), \
                                               Variable(pairweights, requires_grad=False)

        enc, dec = net(inputs_Var)
        loss1 = criterion1(inputs_Var, dec, sampweights)
        loss2 = criterion2(enc, sampweights, pairweights, pairs, index, _sigma1, _sigma2, _lambda)
        loss = loss1 + loss2

        # record loss
        losses1.update(loss1.item(), inputs.size(0))
        losses2.update(loss2.item(), inputs.size(0))
        losses.update(loss.item(), inputs.size(0))

        loss.backward()
        optimizer.step()

    # log to TensorBoard
    if logger:
        logger.log_value('total_loss', losses.avg, epoch)
        logger.log_value('reconstruction_loss', losses1.avg, epoch)
        logger.log_value('dcc_loss', losses2.avg, epoch)
Example #7
def train_one_epoch(dest_aug_mask_perm_dataloader, traj_encoder, dest_proj,
                    aug_proj, mapemb_proj, mask_proj, perm_proj, optimizer,
                    scheduler, criterion_ce, graphregion, config, log_f,
                    log_error):

    traj_encoder.train()
    dest_proj.train()
    aug_proj.train()
    mask_proj.train()
    perm_proj.train()

    losses = AverageMeter()
    losses_dest = AverageMeter()
    losses_aug = AverageMeter()
    losses_mapemb = AverageMeter()
    losses_mask = AverageMeter()
    losses_perm = AverageMeter()

    train_runs = 0
    sample_cnt = 0

    losses_hist = []
    losses_dest_hist = []
    losses_aug_hist = []
    losses_mapemb_hist = []
    losses_mask_hist = []
    losses_perm_hist = []

    while True:

        train_runs += 1

        # re-init loss to zero every iter
        loss = 0.
        try:
            train_batch = next(dest_aug_mask_perm_dataloader)
        except StopIteration:
            log_f.write(
                "All dataloader ran out, finishing {}-th epoch's training. \n".
                format(config.epoch))
            print(
                "All dataloader ran out, finishing {}-th epoch's training. \n".
                format(config.epoch))
            break

        if train_batch is None:  # all filtered out: length<10 or [-1]
            train_runs -= 1
            continue

        batch_dest, batch_aug, batch_mask, batch_perm = train_batch

        ####### Destination ###############################################################
        if 'dest' in config.del_tasks:
            loss_destination = None
        else:
            try:
                loss_destination, out_tm, h_t, w_uh_t, negs, neg_term = compute_destination_loss(
                    batch_dest,
                    traj_encoder,
                    dest_proj,
                    graphregion,
                    config,
                )
                #print("loss_destination", loss_destination)
                loss_destination = config.loss_dest_weight * loss_destination
                loss += loss_destination

            except Exception as e:
                traceback.print_exc()
                log_error.write(traceback.format_exc())
                #             print(e)
                loss_destination = None
                if batch_dest is not None and batch_dest.traj_len.size(0) == 1:
                    # batch size of 1: skip the iteration
                    train_runs -= 1
                    continue

        ####################################################################################
        ####### Augmentation ###############################################################
        if ('aug' in config.del_tasks) and ('mapemb' in config.del_tasks):
            loss_augmentation = None
            loss_mapemb = None
        else:
            try:
                left_aug, right_aug = batch_aug
                is_mapemb = left_aug.traj_len.min() >= 40
                loss_augmentation, loss_mapemb = compute_aug_loss(
                    left_aug, right_aug, traj_encoder, aug_proj, mapemb_proj,
                    graphregion, config, criterion_ce, is_mapemb)
                if ('aug' in config.del_tasks):  # only count loss_mapemb
                    loss_augmentation = None
                    if is_mapemb:  # loss_mapemb is not None
                        loss += loss_mapemb
                if ('mapemb' in config.del_tasks):  # only count loss_aug
                    loss_mapemb = None
                    loss += loss_augmentation

                if ('aug' not in config.del_tasks) and ('mapemb' not in config.del_tasks):
                    loss += loss_augmentation
                    if is_mapemb:  # loss_mapemb is not None
                        loss += loss_mapemb
                #print("loss_augmentation", loss_augmentation)

            except Exception as e:
                traceback.print_exc()
                log_error.write(traceback.format_exc())
                #             print(e)
                loss_augmentation = None
                loss_mapemb = None
                pass
        ####################################################################################
        ####### mask #######################################################################
        if 'mask' in config.del_tasks:
            loss_mask = None
        else:
            try:
                loss_mask, batch_queries, h_t, w_uh_t, _neg_term, neg_term = compute_mask_loss(
                    batch_mask,
                    traj_encoder,
                    mask_proj,
                    graphregion,
                    config,
                )
                #print('loss_mask', loss_mask)
                loss_mask = config.loss_mask_weight * loss_mask
                loss += loss_mask
            except Exception as e:
                traceback.print_exc()
                log_error.write(traceback.format_exc())
                loss_mask = None
                pass
        ####################################################################################
        ####### perm #######################################################################
        if 'perm' in config.del_tasks:
            loss_perm = None
        else:
            try:
                anchor, pos, neg = batch_perm
                loss_perm, logits_perm, target_perm = compute_perm_loss(
                    anchor, pos, neg, traj_encoder, perm_proj, graphregion,
                    config, criterion_ce)
                loss_perm = config.loss_perm_weight * loss_perm
                loss += loss_perm
                #print("loss_perm", loss_perm)
            except Exception as e:
                traceback.print_exc()
                log_error.write(traceback.format_exc())
                #             print(e)
                loss_perm = None
                pass
        ####################################################################################

        if (loss_destination is None) and (loss_augmentation is None) and \
                (loss_mapemb is None) and (loss_mask is None) and (loss_perm is None):
            if ('dest' in config.del_tasks) and ('perm' in config.del_tasks) and \
            ('aug' in config.del_tasks) and ('mask' in config.del_tasks): # model_with_mapemb
                train_runs -= 1
                continue
            else:
                log_f.write(
                    "All loss none, at {}-th epoch's training: check errordata_e{}_step{}.pkl \n"
                    .format(config.epoch, config.epoch, train_runs))
                print(
                    "All loss none, at {}-th epoch's training: check errordata_e{}_step{}.pkl \n"
                    .format(config.epoch, config.epoch, train_runs))
                pickle.dump((batch_dest, batch_aug, batch_mask, batch_perm),
                            open(
                                'errordata_e{}_step{}.pkl'.format(
                                    config.epoch, train_runs), 'wb'))
                train_runs -= 1
                continue

        sample_cnted = False
        try:
            losses.update(loss.item())
        except AttributeError:
            # loss is still a plain float, i.e. no task loss was added this iter
            print(loss, loss_destination, loss_augmentation,
                  loss_mapemb, loss_mask, loss_perm)
        if loss_destination is not None:
            losses_dest.update(loss_destination.item())
            sample_cnt += batch_dest.tm_len.size(0)
            sample_cnted = True

        if loss_augmentation is not None:
            losses_aug.update(loss_augmentation.item())
            if not sample_cnted:
                sample_cnt += left_aug.tm_len.size(0)
                sample_cnted = True

        if loss_mapemb is not None:
            losses_mapemb.update(loss_mapemb.item())
            if not sample_cnted:
                sample_cnt += left_aug.tm_len.size(0)
                sample_cnted = True

        if loss_mask is not None:
            losses_mask.update(loss_mask.item())
            if not sample_cnted:
                sample_cnt += batch_mask.tm_len.size(0)
                sample_cnted = True

        if loss_perm is not None:
            losses_perm.update(loss_perm.item())
            if not sample_cnted:
                sample_cnt += pos.tm_len.size(0)
                sample_cnted = True

        if train_runs % 100 == 0:
            #             print("logits_perm, target_perm: ", logits_perm, target_perm)
            #             print('batch_queries', batch_queries)
            #             print('h_t', h_t)
            #             print('w_uh_t',w_uh_t)
            #             print('_neg_term', _neg_term)
            #             print('neg_term', neg_term)

            if 'perm' not in config.del_tasks:
                acc = (logits_perm.max(1)[1].cpu() == target_perm.view(-1).cpu()) \
                    .sum().to(torch.float32) / target_perm.size(0)
                print("acc : {:.2f}".format(acc))

            losses_hist.append(losses.val)
            losses_dest_hist.append(losses_dest.val)
            losses_aug_hist.append(losses_aug.val)
            losses_mapemb_hist.append(losses_mapemb.val)
            losses_mask_hist.append(losses_mask.val)
            losses_perm_hist.append(losses_perm.val)

            log_f.write(
                'Train Epoch:{} approx. [{}/{}] total_loss:{:.2f}({:.2f})\n'.
                format(config.epoch, sample_cnt, config.n_trains, losses.val,
                       losses.avg))
            log_f.write(
                'loss_destination:{:.2f}({:.2f}) \nloss_augmentation:{:.2f}({:.2f}) \nloss_mapemb:{:.2f}({:.2f}) \nloss_mask:{:.2f}({:.2f}) \nloss_perm:{:.2f}({:.2f}) \n\n'
                .format(
                    losses_dest.val,
                    losses_dest.avg,
                    losses_aug.val,
                    losses_aug.avg,
                    losses_mapemb.val,
                    losses_mapemb.avg,
                    losses_mask.val,
                    losses_mask.avg,
                    losses_perm.val,
                    losses_perm.avg,
                ))
            print('Train Epoch:{} approx. [{}/{}] total_loss:{:.2f}({:.2f})'.
                  format(config.epoch, sample_cnt, config.n_trains, losses.val,
                         losses.avg))
            print(
                'loss_destination:{:.2f}({:.2f}) \nloss_augmentation:{:.2f}({:.2f}) \nloss_mapemb:{:.2f}({:.2f}) \nloss_mask:{:.2f}({:.2f}) \nloss_perm:{:.2f}({:.2f}) \n'
                .format(
                    losses_dest.val,
                    losses_dest.avg,
                    losses_aug.val,
                    losses_aug.avg,
                    losses_mapemb.val,
                    losses_mapemb.avg,
                    losses_mask.val,
                    losses_mask.avg,
                    losses_perm.val,
                    losses_perm.avg,
                ))
            log_f.flush()
            log_error.flush()

        if train_runs % 4500 == 0:
            save_name = (config.name + '_num_hid_layer_' +
                         str(config.num_hidden_layers) +
                         '_step{}'.format(train_runs))
            log_f.write("At step {}, saving model {}.pt\n".format(
                train_runs, save_name))
            print("At step {}, saving model {}.pt".format(
                train_runs, save_name))
            ######
            models_dict = {
                traj_encoder.__class__.__name__: traj_encoder.state_dict(),
                mask_proj.__class__.__name__: mask_proj.state_dict(),
                perm_proj.__class__.__name__: perm_proj.state_dict(),
                aug_proj.__class__.__name__: aug_proj.state_dict(),
                mapemb_proj.__class__.__name__: mapemb_proj.state_dict(),
                dest_proj.__class__.__name__: dest_proj.state_dict(),
            }

            torch.save(models_dict,
                       os.path.join('models', save_name + '.pt'))
            ######


        optimizer.zero_grad()

        loss.backward()

        # every iter
        optimizer.step()

    torch.save((losses_hist, losses_dest_hist, losses_aug_hist,
                losses_mapemb_hist, losses_mask_hist, losses_perm_hist),
               os.path.join('train_hist',
                            config.name + '_loss_hist' +
                            '_hidlayer_' + str(config.num_hidden_layers) +
                            'e' + str(config.epoch) + '.pt'))
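Stripped of the task-specific machinery, train_one_epoch follows a simple pattern: each task contributes an optional loss, the contributions are summed into one scalar, iterations where every task failed are skipped, and a single backward()/step() pair runs per batch. A minimal sketch, with loss_fns standing in for the hypothetical per-task helpers:

def multi_task_step(batch, optimizer, loss_fns):
    # loss_fns: callables that return a scalar loss tensor, or None when
    # their sub-batch was filtered out or their computation failed.
    total, any_loss = 0.0, False
    for fn in loss_fns:
        task_loss = fn(batch)
        if task_loss is not None:
            total = total + task_loss
            any_loss = True
    if not any_loss:
        return None  # skip the iteration, as the loop above does
    optimizer.zero_grad()
    total.backward()
    optimizer.step()
    return total.item()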
Example #8
def train_step_2(trainloader, net_s, net_z, net_d, optimizer_zc, optimizer_d, criterion_rec, criterion_zc, criterion_d, epoch, use_cuda, _sigma1, _sigma2, _lambda):

    losses = AverageMeter()
    losses1 = AverageMeter()
    losses2 = AverageMeter()
    losses_d_rec = AverageMeter()
    losses_d = AverageMeter()

    print('\n Epoch: %d' % epoch)

    net_z.train()
    net_d.train()


    decoder_loss = 0.0
    adversarial_loss = 0.0

    for i, (inputs, pairweights, sampweights, pairs, index) in enumerate(trainloader):

        inputs = torch.squeeze(inputs, 0)
        pairweights = torch.squeeze(pairweights)
        sampweights = torch.squeeze(sampweights)
        index = torch.squeeze(index)
        pairs = pairs.view(-1, 2)

        if use_cuda:
            inputs = inputs.cuda()
            pairweights = pairweights.cuda()
            sampweights = sampweights.cuda()
            index = index.cuda()
            pairs = pairs.cuda()

        inputs, sampweights, pairweights = Variable(inputs), Variable(sampweights, requires_grad=False), \
            Variable(pairweights, requires_grad=False)


        # train z encoder and decoder
        if i % 3 == 0:
            # zero the parameter gradients
            optimizer_d.zero_grad()
            optimizer_zc.zero_grad()
            # forward + backward + optimize

            outputs_s, _ = net_s(inputs)
            outputs_z, dec_z = net_z(inputs)

            loss1 = criterion_rec(inputs, dec_z, sampweights)
            loss2 = criterion_zc(outputs_z, sampweights, pairweights, pairs, index, _sigma1, _sigma2, _lambda)
            loss_zc = loss1 + loss2

            # record loss
            losses1.update(loss1.item(), inputs.size(0))
            losses2.update(loss2.item(), inputs.size(0))
            losses.update(loss_zc.item(), inputs.size(0))

            decoder_input = torch.cat((outputs_s, outputs_z),1)

            outputs_d = net_d(decoder_input)
            #beta = 1.985 # change?
            beta = 1.99 # change?
            loss_d_rec = criterion_d(outputs_d, inputs)
            loss_d = loss_d_rec - beta * loss_zc

            # record loss
            losses_d_rec.update(loss_d_rec.item(), inputs.size(0))
            losses_d.update(loss_d.item(), inputs.size(0))

            loss_d.backward()
            #loss_zc.backward()
            optimizer_d.step()
            optimizer_zc.step()
            decoder_loss += loss_d.item()

            print('dcc_reconstruction_loss', losses1.avg, epoch)
            print('dcc_clustering_loss', losses2.avg, epoch)
            print('dcc_loss', losses.avg, epoch)
            print('total_reconstruction_loss', losses_d_rec.avg, epoch)
            print('total_loss', losses_d.avg, epoch)
            # log to TensorBoard
            if args.tensorboard:
                log_value('dcc_reconstruction_loss', losses1.avg, epoch)
                log_value('dcc_clustering_loss', losses2.avg, epoch)
                log_value('dcc_loss', losses.avg, epoch)
                log_value('total_reconstruction_loss', losses_d_rec.avg, epoch)
                log_value('total_loss', losses_d.avg, epoch)

        # train adversarial clustering
        else:
            # zero the parameter gradients
            optimizer_zc.zero_grad()
            # forward + backward + optimize
            outputs_z, dec_z = net_z(inputs)

            loss1 = criterion_rec(inputs, dec_z, sampweights)
            loss2 = criterion_zc(outputs_z, sampweights, pairweights, pairs, index, _sigma1, _sigma2, _lambda)
            loss_zc = loss1 + loss2

            # record loss
            losses1.update(loss1.item(), inputs.size(0))
            losses2.update(loss2.item(), inputs.size(0))
            losses.update(loss_zc.item(), inputs.size(0))

            loss_zc.backward()
            optimizer_zc.step()
            adversarial_loss += loss_zc.item()


        # print statistics
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print('[%d, %5d] decoder loss: %.3f, adversarial loss: %.3f' %
                  (epoch + 1, i + 1, decoder_loss / 500, adversarial_loss / 1500))
            decoder_loss = 0.0
            adversarial_loss = 0.0
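train_step_2 alternates on a fixed i % 3 schedule: one of every three mini-batches jointly updates the decoder and the z-encoder (subtracting the adversarial term beta * loss_zc), and the remaining two update the z-encoder alone. A minimal sketch of the schedule in isolation; the update_* callables are hypothetical stand-ins for the two branches above:

def alternating_train(trainloader, update_decoder, update_encoder):
    # update_decoder / update_encoder are assumed to run one optimization
    # step on the given batch and return the resulting loss value.
    decoder_loss, adversarial_loss = 0.0, 0.0
    for i, batch in enumerate(trainloader):
        if i % 3 == 0:
            decoder_loss += update_decoder(batch)      # 1 of every 3 batches
        else:
            adversarial_loss += update_encoder(batch)  # 2 of every 3 batches
    return decoder_loss, adversarial_loss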