import time

import torch
from torch.autograd import Variable

# AverageMeter, Log and ranking_loss are assumed to be provided by the
# project's own utility modules (their imports are not shown in this file).


def train(train_loader, args, model, criterion, D_criterion, center_loss, optimizer,
          optimizer1, epoch, num_epochs, mymodel, discriminator):
    """Adversarial training step: the embedding networks minimize the
    per-modality classification losses while trying to fool a modality
    discriminator (center_loss is unused in this variant)."""
    since = time.time()
    running_loss0 = AverageMeter()   # image classification loss
    running_loss1 = AverageMeter()   # video classification loss
    running_loss2 = AverageMeter()   # audio classification loss
    running_loss4 = AverageMeter()   # sum of all media losses
    running_loss6 = AverageMeter()   # optional ranking loss
    running_loss = AverageMeter()    # total generator loss
    running_myloss = AverageMeter()  # text (LSTM + self-attention) loss
    running_lossC = AverageMeter()   # discriminator loss
    log = Log()

    model.train()
    mymodel.train()
    discriminator.train()

    # One-hot modality codes used to select the discriminator output for
    # image, video, audio and text, respectively.
    img_onehot = torch.zeros(1, 4)
    vid_onehot = torch.zeros(1, 4)
    aud_onehot = torch.zeros(1, 4)
    txt_onehot = torch.zeros(1, 4)
    img_onehot[0][0] = 1
    vid_onehot[0][1] = 1
    aud_onehot[0][2] = 1
    txt_onehot[0][3] = 1

    sumx = 0
    sumy = 0
    for i, (input, input1, input2, input3, target) in enumerate(train_loader):
        input_var = Variable(input.cuda())
        input_var1 = Variable(input1.cuda())
        input_var2 = Variable(input2.cuda())
        input_var3 = Variable(input3.cuda())
        target_var = Variable(target.cuda())
        target_var1 = Variable(target.cuda())
        target_var2 = Variable(target.cuda())
        target_var3 = Variable(target.cuda())

        outputs = model(input_var, input_var1, input_var2)
        myloss, mytxt = mymodel.loss(input_var3, target_var3)

        # The backbone stacks image, video and audio embeddings along the batch dim.
        size = int(outputs.size(0) / 3)
        img = outputs.narrow(0, 0, size)
        vid = outputs.narrow(0, size, size)
        aud = outputs.narrow(0, 2 * size, size)

        loss0 = criterion(img, target_var)
        loss1 = criterion(vid, target_var1)
        loss2 = criterion(aud, target_var2)
        loss4 = loss0 + loss1 + loss2 + myloss

        if args.loss_choose == 'r':
            # Labels stacked in the same order as the (image, video, audio) outputs.
            targets = torch.cat((target_var, target_var1, target_var2), 0)
            loss6, _ = ranking_loss(targets, outputs, margin=1, margin2=0.5, squared=False)
            loss6 = loss6 * 0.1
        else:
            loss6 = 0.0
        loss = loss4 + loss6

        # Modality discrimination loss: for each modality, take the
        # discriminator score of the true modality and push it towards 1.
        mysize1 = img.size(0)
        real_label = torch.ones(mysize1)
        lossC = (D_criterion(torch.sum(discriminator(img) * img_onehot.repeat(mysize1, 1).cuda(), dim=1),
                             real_label.cuda())
                 + D_criterion(torch.sum(discriminator(vid) * vid_onehot.repeat(mysize1, 1).cuda(), dim=1),
                               real_label.cuda())
                 + D_criterion(torch.sum(discriminator(aud) * aud_onehot.repeat(mysize1, 1).cuda(), dim=1),
                               real_label.cuda())
                 + D_criterion(torch.sum(discriminator(mytxt) * txt_onehot.repeat(mysize1, 1).cuda(), dim=1),
                               real_label.cuda()))
        g_loss = loss - lossC      # generator: classify well and fool the discriminator
        d_loss = -(loss - lossC)   # discriminator: the opposite objective

        batchsize = input_var.size(0)
        running_loss0.update(loss0.item(), batchsize)
        running_loss1.update(loss1.item(), batchsize)
        running_loss2.update(loss2.item(), batchsize)
        running_loss4.update(loss4.item(), batchsize)
        if args.loss_choose == 'r':
            running_loss6.update(loss6.item(), batchsize)
        running_loss.update(loss.item(), batchsize)
        running_myloss.update(myloss.item(), batchsize)
        running_lossC.update(lossC.item(), batchsize)

        # Alternating updates: generator first (retain the graph so the
        # discriminator step can reuse it), then the discriminator.
        optimizer.zero_grad()
        g_loss.backward(retain_graph=True)
        optimizer.step()

        optimizer1.zero_grad()
        d_loss.backward()
        optimizer1.step()

        sumx += mymodel.loss_n_acc(input_var3, target_var3)[1]
        sumy += input_var3.size(0)
        mytext_acc = sumx / sumy

        if i % args.print_freq == 0:
            print('-' * 20)
            print('Epoch [{0}/{1}][{2}/{3}]'.format(epoch, num_epochs, i, len(train_loader)))
            print('Image Loss: {loss.avg:.5f}'.format(loss=running_loss0))
            print('Video Loss: {loss.avg:.5f}'.format(loss=running_loss1))
            print('Audio Loss: {loss.avg:.5f}'.format(loss=running_loss2))
            print('AllMedia Loss: {loss.avg:.5f}'.format(loss=running_loss4))
            print('lstm+selfattention Loss: {loss.avg:.5f}'.format(loss=running_myloss))
            print('Discriminator Loss: {loss.avg:.5f}'.format(loss=running_lossC))
            if args.loss_choose == 'r':
                print('Ranking Loss: {loss.avg:.5f}'.format(loss=running_loss6))
            print('All Loss: {loss.avg:.5f}'.format(loss=running_loss))
            print('Text train Acc:', mytext_acc)
            log.save_train_info(epoch, i, len(train_loader), running_loss)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
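# Both train() variants in this file track losses with an AverageMeter helper
# that is not defined here. The class below is only a minimal sketch of the
# usual PyTorch-example-style implementation, included for reference; the
# project's actual helper may differ in detail.
class AverageMeter(object):
    """Keeps the last value, a weighted sum, and a running average of a scalar."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0   # most recent value passed to update()
        self.avg = 0.0   # running average, weighted by sample count
        self.sum = 0.0   # weighted sum of all values
        self.count = 0   # total number of samples seen

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count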
def train(train_loader, train_loader1, train_loader2, train_loader3, args, model,
          criterion, center_loss, optimizer, epoch, num_epochs):
    """Train the four-modality network with per-modality cross-entropy,
    a weighted center loss, and an optional ranking loss."""
    since = time.time()
    running_loss0 = AverageMeter()  # image loss
    running_loss1 = AverageMeter()  # video loss
    running_loss2 = AverageMeter()  # audio loss
    running_loss3 = AverageMeter()  # text loss
    running_loss4 = AverageMeter()  # sum of media losses
    running_loss5 = AverageMeter()  # center loss
    running_loss6 = AverageMeter()  # optional ranking loss
    running_loss = AverageMeter()   # total loss
    log = Log()

    model.train()

    # Iterate over the four modality loaders in lockstep.
    for (i, (input, target)), (j, (input1, target1)), (k, (input2, target2)), (p, (input3, target3)) in zip(
            enumerate(train_loader), enumerate(train_loader1),
            enumerate(train_loader2), enumerate(train_loader3)):
        input_var = Variable(input.cuda())
        input_var1 = Variable(input1.cuda())
        input_var2 = Variable(input2.cuda())
        input_var3 = Variable(input3.cuda())
        targets = torch.cat((target, target1, target2, target3), 0)
        targets = Variable(targets.cuda())
        target_var = Variable(target.cuda())
        target_var1 = Variable(target1.cuda())
        target_var2 = Variable(target2.cuda())
        target_var3 = Variable(target3.cuda())

        outputs = model(input_var, input_var1, input_var2, input_var3)

        # The model stacks image, video, audio and text outputs along the batch dim.
        size = int(outputs.size(0) / 4)
        img = outputs.narrow(0, 0, size)
        vid = outputs.narrow(0, size, size)
        aud = outputs.narrow(0, 2 * size, size)
        txt = outputs.narrow(0, 3 * size, size)

        loss0 = criterion(img, target_var)
        loss1 = criterion(vid, target_var1)
        loss2 = criterion(aud, target_var2)
        loss3 = criterion(txt, target_var3)
        loss4 = loss0 + loss1 + loss2 + loss3
        loss5 = center_loss(outputs, targets) * 0.001

        if args.loss_choose == 'r':
            loss6, _ = ranking_loss(targets, outputs, margin=1, margin2=0.5, squared=False)
            loss6 = loss6 * 0.1
        else:
            loss6 = 0.0
        loss = loss4 + loss5 + loss6

        batchsize = input_var.size(0)
        running_loss0.update(loss0.item(), batchsize)
        running_loss1.update(loss1.item(), batchsize)
        running_loss2.update(loss2.item(), batchsize)
        running_loss3.update(loss3.item(), batchsize)
        running_loss4.update(loss4.item(), batchsize)
        running_loss5.update(loss5.item(), batchsize)
        if args.loss_choose == 'r':
            running_loss6.update(loss6.item(), batchsize)
        running_loss.update(loss.item(), batchsize)

        optimizer.zero_grad()
        loss.backward()
        # Undo the 0.001 weighting on the center-loss gradients so the class
        # centers are updated at the nominal learning rate.
        for param in center_loss.parameters():
            param.grad.data *= (1. / 0.001)
        optimizer.step()

        if i % args.print_freq == 0:
            print('-' * 20)
            print('Epoch [{0}/{1}][{2}/{3}]'.format(epoch, num_epochs, i, len(train_loader)))
            print('Image Loss: {loss.avg:.5f}'.format(loss=running_loss0))
            print('Video Loss: {loss.avg:.5f}'.format(loss=running_loss1))
            print('Audio Loss: {loss.avg:.5f}'.format(loss=running_loss2))
            print('Text Loss: {loss.avg:.5f}'.format(loss=running_loss3))
            print('AllMedia Loss: {loss.avg:.5f}'.format(loss=running_loss4))
            print('Center Loss: {loss.avg:.5f}'.format(loss=running_loss5))
            if args.loss_choose == 'r':
                print('Ranking Loss: {loss.avg:.5f}'.format(loss=running_loss6))
            print('All Loss: {loss.avg:.5f}'.format(loss=running_loss))
            log.save_train_info(epoch, i, len(train_loader), running_loss)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
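# A minimal sketch (not the project's actual setup) of how the optimizer passed
# to the center-loss train() above is typically constructed, so that the
# gradient rescaling by 1/0.001 makes sense: both the network parameters and
# the learnable class centers are handed to a single optimizer. The learning
# rate, momentum and weight decay values here are illustrative assumptions.
def build_optimizer_sketch(model, center_loss, lr=1e-3):
    params = list(model.parameters()) + list(center_loss.parameters())
    # Because loss5 = center_loss(...) * 0.001, the centers would otherwise
    # receive gradients scaled down by 0.001; multiplying them back by
    # 1 / 0.001 inside train() lets the centers move at the nominal learning
    # rate while the 0.001 weight still controls how strongly the center loss
    # pulls on the embedding network.
    return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=1e-4)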