def validate(val_loader, model, criterion, device, model_h=None, mode=None):
    """Evaluate ``model`` (optionally followed by ``model_h``) on ``val_loader``.

    Args:
        val_loader: Iterable of samples. ``sample[0]`` is the input batch,
            ``sample[1]`` the default target and ``sample[2]`` the target
            used when ``mode == "d"``.
        model: Network that is switched to eval mode and applied to the inputs.
        criterion: Callable ``criterion(output, target)`` returning the loss.
        device: Device the inputs and targets are moved to.
        model_h: Optional second network applied to ``model``'s output. When
            given, it is switched to eval mode as well.
        mode: ``"d"`` selects ``sample[2]`` as the target; any other value
            selects ``sample[1]``.

    Returns:
        Tuple ``(losses.avg, top1.avg, f1s)``: the averaged recorded loss
        value, the averaged top-1 accuracy and the macro-averaged F1 score
        over all evaluated samples.
    """
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')

    # keep predicted results and gts for calculating the F1 score
    gts = []
    preds = []

    # The target position depends only on ``mode``; decide it once.
    target_idx = 2 if mode == "d" else 1

    # switch to evaluate mode (once, not per batch)
    model.eval()
    if model_h is not None:
        model_h.eval()

    with torch.no_grad():
        for sample in val_loader:
            x = sample[0].to(device)
            t = sample[target_idx].to(device)

            batch_size = x.shape[0]

            # compute output and loss
            if model_h is not None:
                output = model_h(model(x))
            else:
                output = model(x)
            loss = criterion(output, t)

            # measure accuracy and record loss
            acc1 = accuracy(output, t, topk=(1,))
            # NOTE(review): the recorded value is sqrt(loss) * 100, not the
            # raw loss -- confirm this scaling is intended.
            losses.update(math.sqrt(loss.item()) * 100, batch_size)
            top1.update(acc1[0].item(), batch_size)

            # keep predicted results and gts for calculating the F1 score
            _, pred = output.max(dim=1)
            gts += list(t.to("cpu").numpy())
            preds += list(pred.to("cpu").numpy())

    f1s = f1_score(gts, preds, average="macro")

    return losses.avg, top1.avg, f1s
def train(train_loader, model, criterion, optimizer, epoch, device):
    """Run one training pass of ``model`` over ``train_loader``.

    Every sample provides the input batch at index 0 and the targets at
    index 1. For each batch the loss is computed with ``criterion``, the
    meters are updated, and ``optimizer`` takes one step. Progress is
    displayed every 50 iterations (except iteration 0).

    Args:
        train_loader: Iterable of samples; must support ``len()``.
        model: Network to train; switched to train mode.
        criterion: Callable ``criterion(output, target)`` returning the loss.
        optimizer: Optimizer whose gradients are zeroed and stepped per batch.
        epoch: Value shown in the progress prefix.
        device: Device the inputs and targets are moved to.

    Returns:
        Tuple ``(loss_avg, top1_avg, f1s)`` with the averaged loss, the
        averaged top-1 accuracy and the macro-averaged F1 score.
    """
    # meters that keep running averages
    time_meter = AverageMeter('Time', ':6.3f')
    load_meter = AverageMeter('Data', ':6.3f')
    loss_meter = AverageMeter('Loss', ':.4e')
    acc_meter = AverageMeter('Acc@1', ':6.2f')

    # helper that displays the progress
    reporter = ProgressMeter(
        len(train_loader),
        [time_meter, load_meter, loss_meter, acc_meter],
        prefix="Epoch: [{}]".format(epoch),
    )

    # ground truths and predictions collected for the F1 score
    all_targets, all_preds = [], []

    # switch to train mode
    model.train()

    tick = time.time()
    for step, batch in enumerate(train_loader):
        # time spent waiting for the data loader
        load_meter.update(time.time() - tick)

        inputs = batch[0].to(device)
        targets = batch[1].to(device)
        n_samples = inputs.shape[0]

        # forward pass and loss
        logits = model(inputs)
        loss = criterion(logits, targets)

        # record loss and top-1 accuracy
        acc_results = accuracy(logits, targets, topk=(1, ))
        loss_meter.update(loss.item(), n_samples)
        acc_meter.update(acc_results[0].item(), n_samples)

        # remember targets and predicted classes for the F1 score
        predicted = logits.max(dim=1)[1]
        all_targets.extend(targets.cpu().numpy())
        all_preds.extend(predicted.cpu().numpy())

        # backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # time spent on the whole iteration
        time_meter.update(time.time() - tick)
        tick = time.time()

        # show progress every 50 iterations
        if step > 0 and step % 50 == 0:
            reporter.display(step)

    # macro-averaged F1 score over the whole pass
    f1s = f1_score(all_targets, all_preds, average="macro")

    return loss_meter.avg, acc_meter.avg, f1s
def train_adv(train_loader, model_g, model_h, model_d, criterion, optimizer_gh,
              optimizer_d, epoch, device, beta=1):
    """Run one adversarial training pass over ``train_loader``.

    Each batch is processed in two phases:

    1. Discriminator step: ``model_d`` is trained on the features produced by
       ``model_g`` (in eval mode) to predict ``sample[2]``.
    2. Generator step: ``model_g`` and ``model_h`` are trained to predict
       ``sample[1]`` while ``model_d`` (in eval mode) is pushed towards the
       flipped label (0 where ``sample[2] == 1``, otherwise 1), weighted by
       ``beta``.

    Args:
        train_loader: Iterable of samples ``(input, target, group, ...)``;
            must support ``len()``.
        model_g: Feature extractor.
        model_h: Network applied to ``model_g``'s features to predict the target.
        model_d: Discriminator applied to ``model_g``'s features.
        criterion: Callable ``criterion(output, target)`` returning the loss.
        optimizer_gh: Optimizer stepped in the generator phase.
        optimizer_d: Optimizer stepped in the discriminator phase.
        epoch: Value shown in the progress prefix.
        device: Device the models and tensors are moved to.
        beta: Weight of the adversarial loss term added to the generator loss.

    Returns:
        Tuple ``(losses_g.avg, losses_d.avg, top1_g.avg, top1_d.avg, f1s)``
        where ``f1s`` is the macro-averaged F1 score of ``model_h``'s
        predictions against ``sample[1]``.
    """
    model_g.to(device)
    model_h.to(device)
    model_d.to(device)

    # average meters
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses_g = AverageMeter('Loss_g', ':.4e')
    top1_g = AverageMeter('Acc@1_g', ':6.2f')
    losses_d = AverageMeter('Loss_d', ':.4e')
    top1_d = AverageMeter('Acc@1_d', ':6.2f')

    # progress meter
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses_g, top1_g, losses_d, top1_d],
        prefix="Epoch: [{}]".format(epoch)
    )

    # keep predicted results and gts for calculating the F1 score
    gts = []
    preds = []

    end = time.time()
    for i, sample in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        x = sample[0].to(device)
        t = sample[1].to(device)
        g = sample[2].to(device)

        batch_size = x.shape[0]

        # ---- train discriminator ----
        model_g.eval()
        model_d.train()

        # Only optimizer_d steps in this phase, so no autograd graph through
        # model_g is needed for the features.
        with torch.no_grad():
            feat = model_g(x)
        output = model_d(feat)
        loss_d = criterion(output, g)

        optimizer_d.zero_grad()
        loss_d.backward()
        optimizer_d.step()

        acc1 = accuracy(output, g, topk=(1,))
        losses_d.update(loss_d.item(), batch_size)
        top1_d.update(acc1[0].item(), batch_size)

        # ---- train generator ----
        model_g.train()
        model_h.train()
        model_d.eval()

        feat = model_g(x)
        output = model_h(feat)
        loss_g = criterion(output, t)

        # Deceive the discriminator. The features computed above are reused so
        # both loss terms share a single forward pass through model_g.
        output_d = model_d(feat)
        # Flipped labels: 0 where g == 1, otherwise 1 (int64, on g's device).
        # NOTE(review): presumably g is a binary label -- confirm.
        adv_g = (g != 1).long()
        loss_adv_g = criterion(output_d, adv_g)
        loss_g = loss_g + beta * loss_adv_g

        optimizer_gh.zero_grad()
        loss_g.backward()
        optimizer_gh.step()

        acc1 = accuracy(output, t, topk=(1,))
        losses_g.update(loss_g.item(), batch_size)
        top1_g.update(acc1[0].item(), batch_size)

        # keep predicted results and gts for calculating the F1 score
        _, pred = output.max(dim=1)
        gts += list(t.to("cpu").numpy())
        preds += list(pred.to("cpu").numpy())

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # show progress every 50 iterations
        if i != 0 and i % 50 == 0:
            progress.display(i)

    # calculate F1 score
    f1s = f1_score(gts, preds, average="macro")

    return losses_g.avg, losses_d.avg, top1_g.avg, top1_d.avg, f1s