Пример #1
0
def train_gcr(model, criterion, optimizer, optimizer_cnn, trainloader, device,
              epoch, log_interval, writer, args):
    """Train a GCR model for one epoch with a dual-optimizer schedule.

    The CNN backbone optimizer (``optimizer_cnn``) is kept frozen until
    ``epoch > args.cnn_warmup_epochs`` (default 45, the original hard-coded
    threshold); the head optimizer steps on every batch.

    Args:
        model: network called as ``model(data_shot, data_query, lab)``,
            returning ``(logits, label, logits2, gt)``.
        criterion: loss returning ``(total_loss, loss1, loss2)``.
        optimizer: optimizer for the non-backbone parameters.
        optimizer_cnn: optimizer for the CNN backbone.
        trainloader: iterable of ``(data, label)`` batches.
        device: torch device inputs are moved to.
        epoch: current epoch index (used for logging and the warm-up check).
        log_interval: log every this many iterations.
        writer: summary writer handed to the Recorder.
        args: must provide ``shot`` and ``train_way``; may provide
            ``cnn_warmup_epochs`` to override the backbone warm-up length.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    avg_loss1 = AverageMeter()
    avg_loss2 = AverageMeter()
    avg_acc1 = AverageMeter()
    avg_acc2 = AverageMeter()
    # Create recorder
    averagers = [avg_loss1, avg_loss2, avg_acc1, avg_acc2]
    names = ['train loss1', 'train loss2', 'train acc1', 'train acc2']
    recoder = Recorder(averagers, names, writer, batch_time, data_time)
    # Set training mode
    model.train()

    # Backbone warm-up length; falls back to the original constant (45)
    # so existing callers see unchanged behavior.
    cnn_warmup_epochs = getattr(args, 'cnn_warmup_epochs', 45)

    recoder.tik()
    recoder.data_tik()
    for i, batch in enumerate(trainloader):
        # measure data loading time
        recoder.data_tok()

        # get the inputs and labels
        data, lab = [_.to(device) for _ in batch]

        # forward: first `shot * train_way` samples are the support set,
        # the remainder is the query set
        p = args.shot * args.train_way
        data_shot = data[:p]
        data_query = data[p:]

        logits, label, logits2, gt = \
                model(data_shot, data_query, lab)
        # compute the loss
        loss, loss1, loss2 = criterion(logits, label, logits2, gt)

        # backward & optimize
        optimizer.zero_grad()
        optimizer_cnn.zero_grad()
        loss.backward()
        if epoch > cnn_warmup_epochs:
            # Update the CNN backbone only after the warm-up period.
            optimizer_cnn.step()
        optimizer.step()

        # compute the metrics
        acc1 = accuracy(logits, label)[0]
        acc2 = accuracy(logits2, gt)[0]

        # measure elapsed time
        recoder.tok()
        recoder.tik()
        recoder.data_tik()

        # update average value
        vals = [loss1.item(), loss2.item(), acc1, acc2]
        recoder.update(vals)

        if i % log_interval == log_interval - 1:
            recoder.log(epoch, i, len(trainloader))
            # Reset average meters
            recoder.reset()
Пример #2
0
def train_one_epoch(model, trainloader, device, epoch, log_interval, writer):
    """Run a single training epoch, delegating optimization to the model.

    The model's own ``optimize_parameters`` performs the forward/backward/
    step internally and returns six loss terms, which are averaged here
    (weighted by batch size) and logged periodically.
    """
    timer_batch = AverageMeter()
    timer_data = AverageMeter()
    # One meter per loss term, in the same order as the tags below.
    meters = [AverageMeter() for _ in range(6)]
    # Set training mode
    model.train()
    # Create recorder
    tags = [
        'train Ltotal', 'train Lrec', 'train Lltc', 'train Lee',
        'train Ladv G', 'train Ladv D'
    ]
    recoder = Recorder(meters, tags, writer, timer_batch, timer_data)

    recoder.tik()
    recoder.data_tik()
    for batch_idx, sample in enumerate(trainloader):
        # measure data loading time
        recoder.data_tok()

        # unpack the batch; group_names stays on the CPU
        Qh, Ph, glove_angles, group_names = sample
        Qh = Qh.to(device)
        Ph = Ph.to(device)
        glove_angles = glove_angles.to(device)

        # optimization happens inside the model
        Ltotal, Lrec, Lltc, Lee, Ladv_G, Ladv_D = \
            model.optimize_parameters(Qh, Ph)

        # measure elapsed time
        recoder.tok()
        recoder.tik()
        recoder.data_tik()

        # accumulate the scalar losses, weighted by batch size
        scalars = [t.item() for t in (Ltotal, Lrec, Lltc, Lee, Ladv_G, Ladv_D)]
        recoder.update(scalars, count=Qh.size(0))

        if batch_idx == 0 or batch_idx % log_interval == log_interval - 1:
            recoder.log(epoch, batch_idx, len(trainloader))
            # Reset average meters
            recoder.reset()
Пример #3
0
def train_mn_pn(model, criterion, optimizer, trainloader, device, epoch,
                log_interval, writer, args):
    """Train a matching/prototypical network for one epoch.

    Each batch is split into a support set of ``args.shot * args.train_way``
    samples and a query set of the remaining samples; the model returns the
    query predictions together with their episode labels.

    Args:
        model: network called as ``model(data_shot, data_query)``,
            returning ``(y_pred, label)``.
        criterion: classification loss over ``(y_pred, label)``.
        optimizer: optimizer stepped once per batch.
        trainloader: iterable of ``(data, label)`` batches.
        device: torch device inputs are moved to.
        epoch: current epoch index (for logging).
        log_interval: log every this many iterations.
        writer: summary writer handed to the Recorder.
        args: must provide ``shot`` and ``train_way``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    avg_loss = AverageMeter()
    avg_acc = AverageMeter()
    # Create recorder
    averagers = [avg_loss, avg_acc]
    names = ['train loss', 'train acc']
    recoder = Recorder(averagers, names, writer, batch_time, data_time)
    # Set training mode
    model.train()

    recoder.tik()
    recoder.data_tik()
    for i, batch in enumerate(trainloader):
        # measure data loading time
        recoder.data_tok()

        # get the inputs and labels
        data, lab = [_.to(device) for _ in batch]

        # forward: first `shot * train_way` samples form the support set
        p = args.shot * args.train_way
        data_shot = data[:p]
        data_query = data[p:]

        y_pred, label = model(data_shot, data_query)
        # compute the loss
        loss = criterion(y_pred, label)

        # backward & optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # compute the metrics
        acc = accuracy(y_pred, label)[0]

        # measure elapsed time
        recoder.tok()
        recoder.tik()
        recoder.data_tik()

        # update average value
        vals = [loss.item(), acc]
        recoder.update(vals)

        if i % log_interval == log_interval - 1:
            recoder.log(epoch, i, len(trainloader))
            # Reset average meters
            recoder.reset()
Пример #4
0
def train_c3d(model, criterion, optimizer, trainloader, device, epoch,
              log_interval, writer):
    """Standard supervised training epoch for a C3D classifier.

    Tracks the loss plus top-1/top-5 accuracy and logs them periodically
    (including at the very first step).
    """
    time_per_batch = AverageMeter()
    time_per_load = AverageMeter()
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()
    # Create recorder
    recoder = Recorder([loss_meter, top1_meter, top5_meter],
                       ['train loss', 'train top1', 'train top5'],
                       writer, time_per_batch, time_per_load)
    # Set training mode
    model.train()

    recoder.tik()
    recoder.data_tik()
    for step, sample in enumerate(trainloader):
        # measure data loading time
        recoder.data_tok()

        # move the clip and its label onto the target device
        data, lab = (t.to(device) for t in sample)

        optimizer.zero_grad()
        # forward pass
        outputs = model(data)

        # loss computation
        loss = criterion(outputs, lab)

        # backward pass & parameter update
        loss.backward()
        optimizer.step()

        # classification metrics
        top1, top5 = accuracy(outputs, lab, topk=(1, 5))

        # measure elapsed time
        recoder.tok()
        recoder.tik()
        recoder.data_tik()

        # update running averages
        recoder.update([loss.item(), top1, top5])

        # logging
        if step == 0 or step % log_interval == log_interval - 1:
            recoder.log(epoch, step, len(trainloader))
            # Reset average meters
            recoder.reset()
Пример #5
0
def train_text2sign(model, criterion, optimizer, trainloader, device, epoch,
                    log_interval, writer):
    """Train a text-to-sign sequence model for one epoch.

    Args:
        model: seq2seq network called as ``model(src, tgt)``.
        criterion: loss comparing outputs to the target sequence with its
            first frame stripped.
        optimizer: optimizer stepped once per batch.
        trainloader: yields dicts with keys ``'input'`` (shape N x T) and
            ``'tgt'`` (shape N x T2 x J x D).
        device: torch device tensors are moved to.
        epoch: current epoch index (for logging).
        log_interval: log every this many iterations.
        writer: summary writer handed to the Recorder.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    avg_loss = AverageMeter()
    # Set training mode
    model.train()
    # Create recorder
    averagers = [avg_loss]
    names = ['train loss']
    recoder = Recorder(averagers, names, writer, batch_time, data_time)

    recoder.tik()
    recoder.data_tik()
    for i, data in enumerate(trainloader):
        # measure data loading time
        recoder.data_tok()

        # get the inputs and labels
        # shape of src is N x T, shape of tgt is N x T2 x J x D
        # (renamed from `input`, which shadowed the Python builtin)
        src, tgt = data['input'].to(device), data['tgt'].to(device)

        optimizer.zero_grad()
        # forward
        outputs = model(src, tgt)

        # compute the loss, skipping the first target frame
        # (presumably the sos frame, as in the seq2seq loop — TODO confirm)
        loss = criterion(outputs, tgt[:, 1:, :, :])
        # backward & optimize
        loss.backward()

        optimizer.step()

        # measure elapsed time
        recoder.tok()
        recoder.tik()
        recoder.data_tik()

        # update average value, weighted by batch size
        vals = [loss.item()]
        N = src.size(0)
        recoder.update(vals, count=N)

        if i == 0 or i % log_interval == log_interval - 1:
            recoder.log(epoch, i, len(trainloader))
            # Reset average meters
            recoder.reset()
Пример #6
0
def train_one_epoch(model, criterion, optimizer, trainloader, device, epoch,
                    log_interval, writer):
    """Train for one epoch by reconstructing each input from itself.

    The second tensor of every batch is moved to the device but is not
    otherwise used in this loop.
    """
    timer_batch = AverageMeter()
    timer_data = AverageMeter()
    loss_meter = AverageMeter()
    # Set training mode
    model.train()
    # Create recorder
    recoder = Recorder([loss_meter], ['train loss'], writer, timer_batch,
                       timer_data)

    recoder.tik()
    recoder.data_tik()
    for step, sample in enumerate(trainloader):
        # measure data loading time
        recoder.data_tok()

        # move both tensors onto the device
        q, p = (x.to(device) for x in sample)

        optimizer.zero_grad()
        # forward pass
        outputs = model(q)

        # reconstruction loss against the input itself
        loss = criterion(outputs, q)
        # backward & optimize
        loss.backward()

        optimizer.step()

        # measure elapsed time
        recoder.tok()
        recoder.tik()
        recoder.data_tik()

        # update running average, weighted by batch size
        recoder.update([loss.item()], count=q.size(0))

        if step == 0 or step % log_interval == log_interval - 1:
            recoder.log(epoch, step, len(trainloader))
            # Reset average meters
            recoder.reset()
Пример #7
0
def train_seq2seq(model, criterion, optimizer, clip, dataloader, device, epoch,
                  log_interval, writer):
    """Train a seq2seq sign-language recognition model for one epoch.

    Computes cross-entropy loss, token-level accuracy and word error rate
    (WER) per batch, clips gradients to ``clip``, and logs periodically.

    Args:
        model: seq2seq network called as ``model(imgs, target)``.
        criterion: loss over the flattened (outputs, target).
        optimizer: optimizer stepped once per batch.
        clip: max gradient norm passed to ``clip_grad_norm_``.
        dataloader: yields dicts with ``'videos'`` and ``'annotations'``.
        device: torch device inputs are moved to (was previously ignored
            in favor of hard-coded ``.cuda()`` calls).
        epoch: current epoch index (for logging).
        log_interval: log every this many iterations.
        writer: summary writer handed to the Recorder.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    avg_loss = AverageMeter()
    avg_acc = AverageMeter()
    avg_wer = AverageMeter()
    # Create recorder
    averagers = [avg_loss, avg_acc, avg_wer]
    names = ['train loss', 'train acc', 'train wer']
    recoder = Recorder(averagers, names, writer, batch_time, data_time)
    # Set training mode
    model.train()

    recoder.tik()
    recoder.data_tik()
    for batch_idx, batch in enumerate(dataloader):
        # measure data loading time
        recoder.data_tok()
        # get the data and labels; honor the `device` argument instead of
        # hard-coding .cuda(), consistent with the other training loops
        imgs = batch['videos'].to(device)
        target = batch['annotations'].permute(1, 0).contiguous().to(device)

        optimizer.zero_grad()
        # forward
        outputs = model(imgs, target)

        # target: (batch_size, trg len)
        # outputs: (trg_len, batch_size, output_dim)
        # skip sos
        output_dim = outputs.shape[-1]
        outputs = outputs[1:].view(-1, output_dim)
        target = target.permute(1, 0)[1:].reshape(-1)

        # compute the loss
        loss = criterion(outputs, target)

        # compute the accuracy
        prediction = torch.max(outputs, 1)[1]
        score = accuracy_score(target.cpu().data.squeeze().numpy(),
                               prediction.cpu().data.squeeze().numpy())

        # compute wer
        # prediction: ((trg_len-1)*batch_size)
        # target: ((trg_len-1)*batch_size)
        batch_size = imgs.shape[0]
        prediction = prediction.view(-1, batch_size).permute(1, 0).tolist()
        target = target.view(-1, batch_size).permute(1, 0).tolist()
        wers = []
        for i in range(batch_size):
            # add mask (remove padding, sos, eos — token ids 0, 1, 2)
            prediction[i] = [
                item for item in prediction[i] if item not in [0, 1, 2]
            ]
            target[i] = [item for item in target[i] if item not in [0, 1, 2]]
            wers.append(wer(target[i], prediction[i]))
        batch_wer = sum(wers) / len(wers)

        # backward & optimize
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        # measure elapsed time
        recoder.tok()
        recoder.tik()
        recoder.data_tik()

        # update average value, weighted by batch size
        vals = [loss.item(), score, batch_wer]
        b = imgs.size(0)
        recoder.update(vals, count=b)

        if batch_idx == 0 or (batch_idx + 1) % log_interval == 0:
            recoder.log(epoch, batch_idx, len(dataloader))
            # Reset average meters
            recoder.reset()
Пример #8
0
def train_maml(model, criterion, optimizer, trainloader, device, epoch,
               log_interval, writer, args):
    """Train one epoch with MAML (Model-Agnostic Meta-Learning).

    For each task batch: adapt a copy of the weights on the support set for
    ``args.inner_train_steps`` manual SGD steps, evaluate on the query set,
    and accumulate per-task losses/gradients. After the loop, the meta
    update depends on ``args.order``:

    * order 1: average the per-task gradients and inject them via backward
      hooks during a dummy forward/backward pass.
    * order 2: backprop through the mean of the per-task query losses
      (second-order gradients, enabled by ``create_graph=True``).

    Args:
        model: network exposing ``functional_forward(x, weights)`` in
            addition to the usual ``forward``.
        criterion: classification loss.
        optimizer: meta-optimizer over the model's parameters.
        trainloader: iterable of ``(data, label)`` task batches.
        device: torch device inputs are moved to.
        epoch: current epoch index (for logging).
        log_interval: log every this many iterations.
        writer: summary writer handed to the Recorder.
        args: must provide ``shot``, ``query``, ``train_way``, ``order``,
            ``inner_train_steps`` and ``inner_lr``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    avg_loss = AverageMeter()
    avg_acc = AverageMeter()
    # Create recorder
    averagers = [avg_loss, avg_acc]
    names = ['train loss', 'train acc']
    recoder = Recorder(averagers, names, writer, batch_time, data_time)
    # Set trainning mode
    model.train()

    recoder.tik()
    recoder.data_tik()
    # Settings
    # Second-order MAML needs the inner-loop graph retained so gradients of
    # gradients can flow; first-order does not.
    create_graph = (True if args.order == 2 else False)
    task_gradients = []
    task_losses = []
    for i, batch in enumerate(trainloader):
        # measure data loading time
        recoder.data_tok()

        # get the inputs and labels
        data, lab = [_.to(device) for _ in batch]

        # forward
        # data = data.view( ((args.shot+args.query),args.train_way) + data.size()[-3:] )
        # data = data.permute(1,0,2,3,4).contiguous()
        # data = data.view( (-1,) + data.size()[-3:] )
        # First `shot * train_way` samples are the support set, the rest
        # the query set; data_shape is kept for the dummy pass below.
        p = args.shot * args.train_way
        data_shot = data[:p]
        data_query = data[p:]
        data_shape = data_shot.size()[-3:]

        # Create a fast model using the current meta model weights
        fast_weights = OrderedDict(model.named_parameters())

        # Train the model for `inner_train_steps` iterations
        for inner_batch in range(args.inner_train_steps):
            # Perform update of model weights
            y = create_nshot_task_label(args.train_way, args.shot).to(device)
            logits = model.functional_forward(data_shot, fast_weights)
            loss = criterion(logits, y)
            gradients = torch.autograd.grad(loss,
                                            fast_weights.values(),
                                            create_graph=create_graph)

            # Update weights manually (plain SGD step at args.inner_lr;
            # fast_weights stays an OrderedDict so name order is preserved
            # for the zips below)
            fast_weights = OrderedDict(
                (name, param - args.inner_lr * grad)
                for ((name, param),
                     grad) in zip(fast_weights.items(), gradients))

        # Do a pass of the model on the validation data from the current task
        y = create_nshot_task_label(args.train_way, args.query).to(device)
        logits = model.functional_forward(data_query, fast_weights)
        loss = criterion(logits, y)
        # retain_graph=True: the graph is reused by the autograd.grad call
        # below (and by the order-2 meta backward after the loop)
        loss.backward(retain_graph=True)

        # Get post-update accuracies
        y_pred = logits.softmax(-1)
        acc = accuracy(y_pred, y)[0]

        # Accumulate losses and gradients
        task_losses.append(loss)
        gradients = torch.autograd.grad(loss,
                                        fast_weights.values(),
                                        create_graph=create_graph)
        named_grads = {
            name: g
            for ((name, _), g) in zip(fast_weights.items(), gradients)
        }
        task_gradients.append(named_grads)

        # measure elapsed time
        recoder.tok()
        recoder.tik()
        recoder.data_tik()

        # update average value
        vals = [loss.item(), acc]
        recoder.update(vals)

        if i % log_interval == log_interval - 1:
            recoder.log(epoch, i, len(trainloader))
            # Reset average meters
            recoder.reset()

    if args.order == 1:
        # First-order MAML: mean the per-task gradients over the epoch,
        # then smuggle them into the parameters' .grad via backward hooks
        # during a dummy forward/backward pass.
        sum_task_gradients = {
            k: torch.stack([grad[k] for grad in task_gradients]).mean(dim=0)
            for k in task_gradients[0].keys()
        }
        hooks = []
        for name, param in model.named_parameters():
            hooks.append(
                param.register_hook(replace_grad(sum_task_gradients, name)))

        model.train()
        optimizer.zero_grad()
        # Dummy pass in order to create `loss` variable
        # Replace dummy gradients with mean task gradients using hooks
        logits = model(
            torch.zeros((args.train_way, ) + data_shape).to(device,
                                                            dtype=torch.float))
        loss = criterion(logits,
                         create_nshot_task_label(args.train_way, 1).to(device))
        loss.backward()
        optimizer.step()

        # Remove the hooks so later backward passes are unaffected.
        for h in hooks:
            h.remove()

    elif args.order == 2:
        # Second-order MAML: backprop through the mean query loss; the
        # retained inner-loop graphs carry the second-order terms.
        model.train()
        optimizer.zero_grad()
        meta_batch_loss = torch.stack(task_losses).mean()
        meta_batch_loss.backward()
        optimizer.step()
Пример #9
0
def train_cnn(model, criterion, optimizer, trainloader, device, epoch,
              log_interval, writer, args):
    """Train a CNN classifier for one epoch and accumulate class prototypes.

    Besides the usual supervised update, every sample's feature vector is
    summed into a per-class accumulator ``global_proto``, which is
    normalized at the end (base classes by ``args.n_reserve``, the rest by
    ``args.shot``) and returned.

    Args:
        model: classifier exposing ``forward`` and ``get_feature``.
        criterion: classification loss.
        optimizer: optimizer stepped once per batch.
        trainloader: iterable of ``(data, label)`` batches.
        device: torch device inputs are moved to.
        epoch: current epoch index (for logging).
        log_interval: log every this many iterations.
        writer: summary writer handed to the Recorder.
        args: must provide ``num_class``, ``feature_dim``, ``n_base``,
            ``n_reserve`` and ``shot``.

    Returns:
        numpy.ndarray of shape (num_class, feature_dim): averaged class
        prototypes.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    avg_acc = AverageMeter()
    global_proto = numpy.zeros([args.num_class, args.feature_dim])
    # Create recorder
    averagers = [losses, avg_acc]
    names = ['train loss', 'train acc']
    recoder = Recorder(averagers, names, writer, batch_time, data_time)
    # Set training mode
    model.train()

    recoder.tik()
    recoder.data_tik()
    # Enumerate from 0: the original started at 1, which made the `i == 0`
    # logging branch unreachable and shifted logged indices off by one
    # relative to every other training loop in this file.
    for i, batch in enumerate(trainloader):
        # measure data loading time
        recoder.data_tok()

        # get the data and labels
        data, lab = [_.to(device) for _ in batch]

        optimizer.zero_grad()
        # forward
        outputs = model(data)

        # compute the loss
        loss = criterion(outputs, lab)

        # backward & optimize
        loss.backward()
        optimizer.step()

        # Accumulate global prototypes (runs a second feature-extraction
        # pass on the already-updated weights)
        proto = model.get_feature(data)
        for idx, p in enumerate(proto):
            p = p.data.detach().cpu().numpy()
            c = lab[idx]
            global_proto[c] += p
        # compute the metrics
        acc = accuracy(outputs, lab)[0]

        # measure elapsed time
        recoder.tok()
        recoder.tik()
        recoder.data_tik()

        # update average value
        vals = [loss.item(), acc]
        recoder.update(vals)

        # logging
        if i == 0 or i % log_interval == log_interval - 1:
            recoder.log(epoch, i, len(trainloader))
            # Reset average meters
            recoder.reset()

    # Turn the per-class sums into means: base classes saw `n_reserve`
    # samples each, the remaining classes `shot` samples each.
    global_proto[:args.n_base] = global_proto[:args.n_base] / args.n_reserve
    global_proto[args.n_base:] = global_proto[args.n_base:] / args.shot
    return global_proto