Code Example #1
def test(
        data,
        weights=None,
        batch_size=32,
        imgsz=640,
        conf_thres=0.001,
        iou_thres=0.6,  # for NMS
        save_json=False,
        single_cls=False,
        augment=False,
        verbose=False,
        model=None,
        dataloader=None,
        save_dir=Path(''),  # for saving images
        save_txt=False,  # for auto-labelling
        save_hybrid=False,  # for hybrid auto-labelling
        save_conf=False,  # save auto-label confidences
        plots=True,
        log_imgs=0):  # number of logged images

    # Initialize/load model and set device
    training = model is not None
    if training:  # called by train.py
        device = next(model.parameters()).device  # get model device

    else:  # called directly
        set_logging()
        device = select_device(opt.device, batch_size=batch_size)

        # Directories
        save_dir = Path(
            increment_path(Path(opt.project) / opt.name,
                           exist_ok=opt.exist_ok))  # increment run
        (save_dir / 'labels' if save_txt else save_dir).mkdir(
            parents=True, exist_ok=True)  # make dir

        # Load model
        model = attempt_load(weights, map_location=device)  # load FP32 model
        imgsz = check_img_size(imgsz, s=model.stride.max())  # check img_size

        # Multi-GPU disabled, incompatible with .half() https://github.com/ultralytics/yolov5/issues/99
        # if device.type != 'cpu' and torch.cuda.device_count() > 1:
        #     model = nn.DataParallel(model)

    # Half
    half = device.type != 'cpu'  # half precision only supported on CUDA
    if half:
        model.half()

    # Configure
    model.eval()
    is_coco = data.endswith('coco.yaml')  # is COCO dataset
    with open(data) as f:
        data = yaml.load(f, Loader=yaml.FullLoader)  # data dict
    check_dataset(data)  # check
    nc = 1 if single_cls else int(data['nc'])  # number of classes
    iouv = torch.linspace(0.5, 0.95, 10).to(device)  # iou vector for mAP@0.5:0.95
    niou = iouv.numel()

    # Logging
    log_imgs, wandb = min(log_imgs, 100), None  # cap logged images at 100
    try:
        import wandb  # Weights & Biases
    except ImportError:
        log_imgs = 0

    # Dataloader
    if not training:
        img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
        _ = model(img.half() if half else img) if device.type != 'cpu' else None  # run once
        path = data['test'] if opt.task == 'test' else data['val']  # path to val/test images
        dataloader = create_dataloader(path,
                                       imgsz,
                                       batch_size,
                                       model.stride.max(),
                                       opt,
                                       pad=0.5,
                                       rect=True)[0]

    seen = 0
    confusion_matrix = ConfusionMatrix(nc=nc)
    names = {
        k: v
        for k, v in enumerate(
            model.names if hasattr(model, 'names') else model.module.names)
    }
    coco91class = coco80_to_coco91_class()
    s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
    p, r, f1, mp, mr, map50, map, t0, t1 = 0., 0., 0., 0., 0., 0., 0., 0., 0.
    loss = torch.zeros(3, device=device)
    jdict, stats, ap, ap_class, wandb_images = [], [], [], [], []
    for batch_i, (img, targets, paths,
                  shapes) in enumerate(tqdm(dataloader, desc=s)):
        img = img.to(device, non_blocking=True)
        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        targets = targets.to(device)
        nb, _, height, width = img.shape  # batch size, channels, height, width

        with torch.no_grad():
            # Run model
            t = time_synchronized()
            inf_out, train_out = model(
                img, augment=augment)  # inference and training outputs
            t0 += time_synchronized() - t

            # Compute loss
            if training:
                loss += compute_loss([x.float() for x in train_out], targets,
                                     model)[1][:3]  # box, obj, cls

            # Run NMS
            targets[:, 2:] *= torch.Tensor([width, height, width,
                                            height]).to(device)  # to pixels
            lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else []  # for autolabelling
            t = time_synchronized()
            output = non_max_suppression(inf_out,
                                         conf_thres=conf_thres,
                                         iou_thres=iou_thres,
                                         labels=lb)
            t1 += time_synchronized() - t

        # Statistics per image
        for si, pred in enumerate(output):
            labels = targets[targets[:, 0] == si, 1:]
            nl = len(labels)
            tcls = labels[:, 0].tolist() if nl else []  # target class
            path = Path(paths[si])
            seen += 1

            if len(pred) == 0:
                if nl:
                    stats.append((torch.zeros(0, niou, dtype=torch.bool),
                                  torch.Tensor(), torch.Tensor(), tcls))
                continue

            # Predictions
            predn = pred.clone()
            scale_coords(img[si].shape[1:], predn[:, :4], shapes[si][0],
                         shapes[si][1])  # native-space pred

            # Append to text file
            if save_txt:
                gn = torch.tensor(shapes[si][0])[[1, 0, 1, 0]]  # normalization gain whwh
                for *xyxy, conf, cls in predn.tolist():
                    xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) /
                            gn).view(-1).tolist()  # normalized xywh
                    line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
                    with open(save_dir / 'labels' / (path.stem + '.txt'),
                              'a') as f:
                        f.write(('%g ' * len(line)).rstrip() % line + '\n')

            # W&B logging
            if plots and len(wandb_images) < log_imgs:
                box_data = [{
                    "position": {
                        "minX": xyxy[0],
                        "minY": xyxy[1],
                        "maxX": xyxy[2],
                        "maxY": xyxy[3]
                    },
                    "class_id": int(cls),
                    "box_caption": "%s %.3f" % (names[cls], conf),
                    "scores": {
                        "class_score": conf
                    },
                    "domain": "pixel"
                } for *xyxy, conf, cls in pred.tolist()]
                boxes = {
                    "predictions": {
                        "box_data": box_data,
                        "class_labels": names
                    }
                }  # inference-space
                wandb_images.append(
                    wandb.Image(img[si], boxes=boxes, caption=path.name))

            # Append to pycocotools JSON dictionary
            if save_json:
                # [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ...
                image_id = int(path.stem) if path.stem.isnumeric() else path.stem
                box = xyxy2xywh(predn[:, :4])  # xywh
                box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
                for p, b in zip(pred.tolist(), box.tolist()):
                    jdict.append({
                        'image_id':
                        image_id,
                        'category_id':
                        coco91class[int(p[5])] if is_coco else int(p[5]),
                        'bbox': [round(x, 3) for x in b],
                        'score':
                        round(p[4], 5)
                    })

            # Assign all predictions as incorrect
            correct = torch.zeros(pred.shape[0],
                                  niou,
                                  dtype=torch.bool,
                                  device=device)
            if nl:
                detected = []  # target indices
                tcls_tensor = labels[:, 0]

                # target boxes
                tbox = xywh2xyxy(labels[:, 1:5])
                scale_coords(img[si].shape[1:], tbox, shapes[si][0],
                             shapes[si][1])  # native-space labels
                if plots:
                    confusion_matrix.process_batch(
                        pred, torch.cat((labels[:, 0:1], tbox), 1))

                # Per target class
                for cls in torch.unique(tcls_tensor):
                    ti = (cls == tcls_tensor).nonzero(as_tuple=False).view(-1)  # target indices
                    pi = (cls == pred[:, 5]).nonzero(as_tuple=False).view(-1)  # prediction indices

                    # Search for detections
                    if pi.shape[0]:
                        # Prediction to target ious
                        ious, i = box_iou(predn[pi, :4], tbox[ti]).max(
                            1)  # best ious, indices

                        # Append detections
                        detected_set = set()
                        for j in (ious > iouv[0]).nonzero(as_tuple=False):
                            d = ti[i[j]]  # detected target
                            if d.item() not in detected_set:
                                detected_set.add(d.item())
                                detected.append(d)
                                correct[pi[j]] = ious[j] > iouv  # iou_thres is 1xn
                                if len(detected) == nl:  # all targets already located in image
                                    break

            # Append statistics (correct, conf, pcls, tcls)
            stats.append(
                (correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))

        # Plot images
        if plots and batch_i < 3:
            f = save_dir / f'test_batch{batch_i}_labels.jpg'  # labels
            Thread(target=plot_images,
                   args=(img, targets, paths, f, names),
                   daemon=True).start()
            f = save_dir / f'test_batch{batch_i}_pred.jpg'  # predictions
            Thread(target=plot_images,
                   args=(img, output_to_target(output), paths, f, names),
                   daemon=True).start()

    # Compute statistics
    stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy
    if len(stats) and stats[0].any():
        p, r, ap, f1, ap_class = ap_per_class(*stats,
                                              plot=plots,
                                              save_dir=save_dir,
                                              names=names)
        p, r, ap50, ap = p[:, 0], r[:, 0], ap[:, 0], ap.mean(1)  # [P, R, mAP@.5, mAP@.5:.95]
        mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
        nt = np.bincount(stats[3].astype(np.int64),
                         minlength=nc)  # number of targets per class
    else:
        nt = torch.zeros(1)

    # Print results
    pf = '%20s' + '%12.3g' * 6  # print format
    print(pf % ('all', seen, nt.sum(), mp, mr, map50, map))

    # Print results per class
    if verbose and nc > 1 and len(stats):
        for i, c in enumerate(ap_class):
            print(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))

    # Print speeds
    t = tuple(x / seen * 1E3
              for x in (t0, t1, t0 + t1)) + (imgsz, imgsz, batch_size)  # tuple
    if not training:
        print(
            'Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g'
            % t)

    # Plots
    if plots:
        confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
        if wandb and wandb.run:
            wandb.log({"Images": wandb_images})
            wandb.log({
                "Validation": [
                    wandb.Image(str(f), caption=f.name)
                    for f in sorted(save_dir.glob('test*.jpg'))
                ]
            })

    # Save JSON
    if save_json and len(jdict):
        w = Path(weights[0] if isinstance(weights, list) else weights
                 ).stem if weights is not None else ''  # weights
        anno_json = '../coco/annotations/instances_val2017.json'  # annotations json
        pred_json = str(save_dir / f"{w}_predictions.json")  # predictions json
        print('\nEvaluating pycocotools mAP... saving %s...' % pred_json)
        with open(pred_json, 'w') as f:
            json.dump(jdict, f)

        try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
            from pycocotools.coco import COCO
            from pycocotools.cocoeval import COCOeval

            anno = COCO(anno_json)  # init annotations api
            pred = anno.loadRes(pred_json)  # init predictions api
            eval = COCOeval(anno, pred, 'bbox')
            if is_coco:
                eval.params.imgIds = [
                    int(Path(x).stem) for x in dataloader.dataset.img_files
                ]  # image IDs to evaluate
            eval.evaluate()
            eval.accumulate()
            eval.summarize()
            map, map50 = eval.stats[:2]  # update results (mAP@0.5:0.95, mAP@0.5)
        except Exception as e:
            print(f'pycocotools unable to run: {e}')

    # Return results
    if not training:
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
        print(f"Results saved to {save_dir}{s}")
    model.float()  # for training
    maps = np.zeros(nc) + map
    for i, c in enumerate(ap_class):
        maps[c] = ap[i]
    return (mp, mr, map50, map,
            *(loss.cpu() / len(dataloader)).tolist()), maps, t
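
The first tuple returned above packs mean precision, mean recall, mAP@.5, mAP@.5:.95 and the averaged box/obj/cls validation losses. The following is a minimal, hypothetical sketch of pushing those values to Weights & Biases after calling test(); the `results` variable and the metric key names are assumptions, not part of the original script.

import wandb

# Assumes `results, maps, times = test(...)` was executed with a wandb run active.
# results = (mp, mr, map50, map, box_loss, obj_loss, cls_loss)
mp, mr, map50, map95, box_loss, obj_loss, cls_loss = results
wandb.log({
    'val/precision': mp,
    'val/recall': mr,
    'val/mAP_0.5': map50,
    'val/mAP_0.5:0.95': map95,
    'val/box_loss': box_loss,
    'val/obj_loss': obj_loss,
    'val/cls_loss': cls_loss,
})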
Code Example #2
    def future_v(self):
        self.stat_speed = self.vf * torch.exp(
            torch.div(-1, self.a_var + self.TINY) *
            torch.pow(torch.div(self.current_densities, self.rhocr + self.TINY) + self.TINY, self.a_var))
        self.stat_speed = torch.clamp(self.stat_speed,
                                      min=self.vmin,
                                      max=self.vmax)
        try:
            if self.print_count % self.print_every == 0:
                wandb.log(
                    {"vf": wandb.Histogram(self.vf.cpu().detach().numpy())})
                wandb.log({
                    "a_var":
                    wandb.Histogram(self.a_var.cpu().detach().numpy())
                })
                wandb.log({
                    "rhocr":
                    wandb.Histogram(self.rhocr.cpu().detach().numpy())
                })
                wandb.log(
                    {"g": wandb.Histogram(self.g_var.cpu().detach().numpy())})
                wandb.log(
                    {"q0": wandb.Histogram(self.q0.cpu().detach().numpy())})
                wandb.log({
                    "rhoNp1":
                    wandb.Histogram(self.rhoNp1.cpu().detach().numpy())
                })
                wandb.log({
                    "current_velocities":
                    wandb.Histogram(
                        self.current_velocities.cpu().detach().numpy())
                })
                wandb.log({
                    'mean_current_velocities':
                    self.current_velocities.cpu().detach().numpy().mean(),
                    'mean_current_densities':
                    self.current_densities.cpu().detach().numpy().mean(),
                    'mean_current_flows':
                    self.current_flows.cpu().detach().numpy().mean(),
                    'mean_current_onramp':
                    self.current_onramp.cpu().detach().numpy().mean(),
                    'mean_current_offramp':
                    self.current_offramp.cpu().detach().numpy().mean(),
                    'mean_v0':
                    self.v0.cpu().detach().numpy().mean(),
                    'mean_q0':
                    self.q0.cpu().detach().numpy().mean(),
                    'mean_rhoNp1':
                    self.rhoNp1.cpu().detach().numpy().mean()
                })
                wandb.log({
                    "current_densities":
                    wandb.Histogram(
                        self.current_densities.cpu().detach().numpy())
                })
                wandb.log({
                    "current_flows":
                    wandb.Histogram(self.current_flows.cpu().detach().numpy())
                })
                wandb.log({
                    "current_onramp":
                    wandb.Histogram(self.current_onramp.cpu().detach().numpy())
                })
                wandb.log({
                    "current_offramp":
                    wandb.Histogram(
                        self.current_offramp.cpu().detach().numpy())
                })

                wandb.log({
                    "current_r_4":
                    wandb.Histogram(
                        self.current_onramp[:, 3].cpu().detach().numpy())
                })
                wandb.log({
                    "current_s_2":
                    wandb.Histogram(
                        self.current_offramp[:, 1].cpu().detach().numpy())
                })

                wandb.log({
                    "current_flows_1":
                    wandb.Histogram(
                        self.current_flows[:, 0].cpu().detach().numpy())
                })
                wandb.log({
                    "current_flows_2":
                    wandb.Histogram(
                        self.current_flows[:, 1].cpu().detach().numpy())
                })
                wandb.log({
                    "current_flows_3":
                    wandb.Histogram(
                        self.current_flows[:, 2].cpu().detach().numpy())
                })
                wandb.log({
                    "current_flows_4":
                    wandb.Histogram(
                        self.current_flows[:, 3].cpu().detach().numpy())
                })

                wandb.log({
                    "current_flows_1_to_2":
                    wandb.Histogram(
                        self.current_flows[:, 1].cpu().detach().numpy() -
                        self.current_flows[:, 0].cpu().detach().numpy())
                })
                wandb.log({
                    "current_flows_2_to_3":
                    wandb.Histogram(
                        self.current_flows[:, 2].cpu().detach().numpy() -
                        self.current_flows[:, 1].cpu().detach().numpy())
                })
                wandb.log({
                    "current_flows_3_to_4":
                    wandb.Histogram(
                        self.current_flows[:, 3].cpu().detach().numpy() -
                        self.current_flows[:, 2].cpu().detach().numpy())
                })

                wandb.log({
                    "stat_speed":
                    wandb.Histogram(self.stat_speed.cpu().detach().numpy())
                })
                wandb.log(
                    {"v0": wandb.Histogram(self.v0.cpu().detach().numpy())})
                wandb.log(
                    {"q0": wandb.Histogram(self.q0.cpu().detach().numpy())})
                wandb.log({
                    "rhoNp1":
                    wandb.Histogram(self.rhoNp1.cpu().detach().numpy())
                })
                q_max = self.rhocr * self.vf * torch.exp(
                    torch.div(-1, self.a_var + self.TINY))
                wandb.log(
                    {"q_max": wandb.Histogram(q_max.cpu().detach().numpy())})
                wandb.log({
                    "t_var":
                    wandb.Histogram(self.t_var.cpu().detach().numpy())
                })
                wandb.log(
                    {"tau": wandb.Histogram(self.tau.cpu().detach().numpy())})
                wandb.log(
                    {"nu": wandb.Histogram(self.nu.cpu().detach().numpy())})
                wandb.log({
                    "delta":
                    wandb.Histogram(self.delta.cpu().detach().numpy())
                })
                wandb.log({
                    "kappa":
                    wandb.Histogram(self.kappa.cpu().detach().numpy())
                })
                wandb.log({
                    "cap_delta":
                    wandb.Histogram(self.cap_delta.cpu().detach().numpy())
                })
                wandb.log({
                    "lambda_var":
                    wandb.Histogram(self.lambda_var.cpu().detach().numpy())
                })
                wandb.log({
                    "epsv":
                    wandb.Histogram(self.epsv.cpu().detach().numpy())
                })
        except Exception as e:
            print(e)

        return (self.current_velocities
                + torch.div(self.t_var, self.tau + self.TINY) * (self.stat_speed - self.current_velocities)
                + torch.div(self.t_var, self.cap_delta) * self.current_velocities * (self.prev_velocities - self.current_velocities)
                - torch.div(self.nu * self.t_var, self.tau * self.cap_delta) * torch.div(self.next_densities - self.current_densities, self.current_densities + self.kappa)
                - torch.div(self.delta * self.t_var, self.cap_delta * self.lambda_var + self.TINY) * torch.div(self.current_onramp * self.current_velocities, self.current_densities + self.kappa)
                + self.epsv)
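
Every wandb.log call above advances the run's step counter, so one print interval scatters these histograms over dozens of consecutive steps. A minimal sketch of an alternative, not taken from the original class, is to collect everything into one dictionary and log it in a single call; the helper name and the example keys are illustrative.

import wandb


def log_histograms(tensors, step=None):
    # Convert each tensor to a numpy array and log all histograms in one wandb.log call.
    payload = {
        name: wandb.Histogram(t.detach().cpu().numpy())
        for name, t in tensors.items()
    }
    wandb.log(payload, step=step)


# Hypothetical usage inside future_v(), assuming the same attributes exist:
# log_histograms({'vf': self.vf, 'a_var': self.a_var, 'rhocr': self.rhocr, 'stat_speed': self.stat_speed})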
Code Example #3
    }
else:
    result = {
        'count_acc': stats['count'] / stats['count_tot'],
        'exist_acc': stats['exist'] / stats['exist_tot'],
        'compare_num_acc': stats['compare_num'] / stats['compare_num_tot'],
        'compare_attr_acc': stats['compare_attr'] / stats['compare_attr_tot'],
        'query_acc': stats['query'] / stats['query_tot'],
        'program_acc': stats['correct_prog'] / stats['total'],
        'overall_acc': stats['correct_ans'] / stats['total']
    }
logging.debug(result)
if opt.visualize_training_wandb:
    # Results (and corresponding options) are available on the wandb server in the test_opt_{ts}.json file dump
    val_qfn = opt.clevr_val_question_filename.split(".")[0]
    wandb.log({"Batch Stats": table})
    wandb.log({f"Sample Stats": sample_table})
    wandb.log({f"Sample Stats [{val_qfn}] (Correct)": sample_table_correct})
    wandb.log(
        {f"Sample Stats [{val_qfn}] (Incorrect)": sample_table_incorrect})
    wandb.log({
        "q_length_corr":
        wandb.Histogram(sample_stats['q_lens_corr'], num_bins=10)
    })
    wandb.log({
        "q_length_incorr":
        wandb.Histogram(sample_stats['q_lens_incorr'], num_bins=10)
    })
    log_params(opt, result)
else:
    # Record local copies of results #
Code Example #4
    wandb.run.summary["run_id"] = os.path.basename(wandb.run.dir)
    # Training
    for i, (img, lbl) in enumerate(train_dataset):
        model_train(img, lbl, model, optimizer_us)

    # Validate
    for i, (img, lbl) in enumerate(test_dataset):
        model_validate_us(img, lbl, model)

    model.save(model_path, overwrite=True)

    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    if epoch % 1 == 0:
        print(template.format(epoch + 1,
                              round(train_loss.result().numpy(), 2),
                              round(train_acc.result().numpy() * 100, 2),
                              round(valid_loss.result().numpy(), 2),
                              round(valid_acc.result().numpy() * 100, 2)))

        wandb.log(dict(loss=round(train_loss.result().numpy(), 2),
                       accuracy=round(train_acc.result().numpy() * 100, 2),
                       test_loss=round(valid_loss.result().numpy(), 2),
                       test_accuracy=round(valid_acc.result().numpy() * 100,
                                           2)),
                  step=epoch)

    train_loss.reset_states()
    valid_loss.reset_states()
    train_acc.reset_states()
    valid_acc.reset_states()
Code Example #5
            pred_seqs.extend(
                ' '.join(output_vocab.itos[tok] for tok in acc_tensor[i]
                         if output_vocab.itos[tok] not in specials)
                for i in range(len(batch))
            )

            loss.backward()
            encoder_optimizer.step()
            decoder_optimizer.step()
            avg_loss += loss.item() / len(batch)
            # loss_another_epoch = F.nll_loss(output_another.squeeze(1), target_variables.squeeze(0)).item() / len(batch)
        # compute bleu
        epoch_bleu = run_perl_script_and_parse_result('\n'.join(tgt_seqs),
                                                      '\n'.join(pred_seqs),
                                                      perl_script_path)
        print("TARGET = {}\nPREDICTED = {}".format(', '.join(tgt_seqs), ', '.join(pred_seqs)))
        if epoch_bleu:
            wandb.log({'bleu': epoch_bleu.bleu, 'avg_loss': avg_loss, 'another_loss': loss_another_epoch})
            tk0.set_postfix(loss=avg_loss, train_bleu=epoch_bleu.bleu, another_loss=loss_another_epoch)
        else:
            wandb.log({'avg_loss': avg_loss})
            tk0.set_postfix(loss=avg_loss)
    wandb.save('model_nmt2.h5')

    # Save model after every epoch (optional)
    # torch.save({"encoder":encoder.state_dict(),
    #             "decoder":decoder.state_dict(),
    #             "e_optimizer":encoder_optimizer.state_dict(),
    #             "d_optimizer":decoder_optimizer},
    #            "./model.pt")
Code Example #6
import matplotlib.pyplot as plt
import wandb
from sklearn.metrics import roc_curve


def plot_roc_curve(label, predict):
    # Compute false/true positive rates and log the resulting matplotlib figure to W&B.
    fpr, tpr, thresholds = roc_curve(label, predict)
    plt.plot(fpr, tpr)
    plt.xlabel('FPR')
    plt.ylabel('TPR')
    wandb.log({"Plot ROC curve": plt})
Code Example #7
    def train(self,
              train_dataset,
              output_dir,
              show_running_loss=True,
              eval_df=None,
              verbose=True,
              **kwargs):
        """
        Trains the model on train_dataset.

        Utility function to be used by the train_model() method. Not intended to be used directly.
        """

        device = self.device
        model = self.model
        args = self.args

        tb_writer = SummaryWriter(logdir=args["tensorboard_dir"])
        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args["train_batch_size"])

        t_total = len(train_dataloader) // args[
            "gradient_accumulation_steps"] * args["num_train_epochs"]

        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                args["weight_decay"],
            },
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.0,
            },
        ]

        warmup_steps = math.ceil(t_total * args["warmup_ratio"])
        args["warmup_steps"] = warmup_steps if args[
            "warmup_steps"] == 0 else args["warmup_steps"]

        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=args["learning_rate"],
            eps=args["adam_epsilon"],
        )
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args["warmup_steps"],
            num_training_steps=t_total)

        if args["fp16"]:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )

            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args["fp16_opt_level"])

        if args["n_gpu"] > 1:
            model = torch.nn.DataParallel(model)

        global_step = 0
        tr_loss, logging_loss = 0.0, 0.0
        model.zero_grad()
        train_iterator = trange(int(args["num_train_epochs"]),
                                desc="Epoch",
                                disable=args["silent"],
                                mininterval=0)
        epoch_number = 0
        best_eval_metric = None
        early_stopping_counter = 0
        steps_trained_in_current_epoch = 0
        epochs_trained = 0

        if args["model_name"] and os.path.exists(args["model_name"]):
            try:
                # set global_step to global_step of last saved checkpoint from model path
                checkpoint_suffix = args["model_name"].split("/")[-1].split("-")
                if len(checkpoint_suffix) > 2:
                    checkpoint_suffix = checkpoint_suffix[1]
                else:
                    checkpoint_suffix = checkpoint_suffix[-1]
                global_step = int(checkpoint_suffix)
                epochs_trained = global_step // (
                    len(train_dataloader) //
                    args["gradient_accumulation_steps"])
                steps_trained_in_current_epoch = global_step % (
                    len(train_dataloader) //
                    args["gradient_accumulation_steps"])

                logger.info(
                    "   Continuing training from checkpoint, will skip to saved global_step"
                )
                logger.info("   Continuing training from epoch %d",
                            epochs_trained)
                logger.info("   Continuing training from global step %d",
                            global_step)
                logger.info(
                    "   Will skip the first %d steps in the current epoch",
                    steps_trained_in_current_epoch)
            except ValueError:
                logger.info("   Starting fine-tuning.")

        if args["evaluate_during_training"]:
            training_progress_scores = self._create_training_progress_scores(
                **kwargs)
        if args["wandb_project"]:
            wandb.init(project=args["wandb_project"],
                       config={**args},
                       **args["wandb_kwargs"])
            wandb.watch(self.model)

        model.train()
        for _ in train_iterator:
            if epochs_trained > 0:
                epochs_trained -= 1
                continue
            # epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            for step, batch in enumerate(
                    tqdm(train_dataloader,
                         desc="Current iteration",
                         disable=args["silent"])):
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue
                batch = tuple(t.to(device) for t in batch)

                inputs = self._get_inputs_dict(batch)

                outputs = model(**inputs)
                # model outputs are always tuples in pytorch-transformers (see docs)
                loss = outputs[0]

                if args["n_gpu"] > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu parallel training

                current_loss = loss.item()

                if show_running_loss:
                    print("\rRunning loss: %f" % loss, end="")

                if args["gradient_accumulation_steps"] > 1:
                    loss = loss / args["gradient_accumulation_steps"]

                if args["fp16"]:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    # torch.nn.utils.clip_grad_norm_(
                    #     amp.master_params(optimizer), args["max_grad_norm"]
                    # )
                else:
                    loss.backward()
                    # torch.nn.utils.clip_grad_norm_(
                    #     model.parameters(), args["max_grad_norm"]
                    # )

                tr_loss += loss.item()
                if (step + 1) % args["gradient_accumulation_steps"] == 0:
                    if args["fp16"]:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer),
                            args["max_grad_norm"])
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args["max_grad_norm"])
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    model.zero_grad()
                    global_step += 1

                    if args["logging_steps"] > 0 and global_step % args[
                            "logging_steps"] == 0:
                        # Log metrics
                        tb_writer.add_scalar("lr",
                                             scheduler.get_lr()[0],
                                             global_step)
                        tb_writer.add_scalar(
                            "loss",
                            (tr_loss - logging_loss) / args["logging_steps"],
                            global_step,
                        )
                        logging_loss = tr_loss
                        if args["wandb_project"]:
                            wandb.log({
                                "Training loss": current_loss,
                                "lr": scheduler.get_lr()[0],
                                "global_step": global_step,
                            })

                    if args["save_steps"] > 0 and global_step % args[
                            "save_steps"] == 0:
                        # Save model checkpoint
                        output_dir_current = os.path.join(
                            output_dir, "checkpoint-{}".format(global_step))

                        self._save_model(output_dir_current,
                                         optimizer,
                                         scheduler,
                                         model=model)

                    if args["evaluate_during_training"] and (
                            args["evaluate_during_training_steps"] > 0
                            and global_step %
                            args["evaluate_during_training_steps"] == 0):
                        # Only evaluate when single GPU otherwise metrics may not average well
                        results, _, _ = self.eval_model(
                            eval_df,
                            verbose=verbose
                            and args["evaluate_during_training_verbose"],
                            **kwargs)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)

                        output_dir_current = os.path.join(
                            output_dir, "checkpoint-{}".format(global_step))

                        os.makedirs(output_dir_current, exist_ok=True)

                        if args["save_eval_checkpoints"]:
                            self._save_model(output_dir_current,
                                             optimizer,
                                             scheduler,
                                             model=model,
                                             results=results)

                        training_progress_scores["global_step"].append(
                            global_step)
                        training_progress_scores["train_loss"].append(
                            current_loss)
                        for key in results:
                            training_progress_scores[key].append(results[key])
                        report = pd.DataFrame(training_progress_scores)
                        report.to_csv(
                            os.path.join(args["output_dir"],
                                         "training_progress_scores.csv"),
                            index=False,
                        )

                        if args["wandb_project"]:
                            wandb.log(
                                self._get_last_metrics(
                                    training_progress_scores))

                        if not best_eval_metric:
                            best_eval_metric = results[
                                args["early_stopping_metric"]]
                            self._save_model(args["best_model_dir"],
                                             optimizer,
                                             scheduler,
                                             model=model,
                                             results=results)
                        if best_eval_metric and args[
                                "early_stopping_metric_minimize"]:
                            if (results[args["early_stopping_metric"]] -
                                    best_eval_metric <
                                    args["early_stopping_delta"]):
                                best_eval_metric = results[
                                    args["early_stopping_metric"]]
                                self._save_model(args["best_model_dir"],
                                                 optimizer,
                                                 scheduler,
                                                 model=model,
                                                 results=results)
                                early_stopping_counter = 0
                            else:
                                if args["use_early_stopping"]:
                                    if early_stopping_counter < args[
                                            "early_stopping_patience"]:
                                        early_stopping_counter += 1
                                        if verbose:
                                            logger.info(
                                                f" No improvement in {args['early_stopping_metric']}"
                                            )
                                            logger.info(
                                                f" Current step: {early_stopping_counter}"
                                            )
                                            logger.info(
                                                f" Early stopping patience: {args['early_stopping_patience']}"
                                            )
                                    else:
                                        if verbose:
                                            logger.info(
                                                f" Patience of {args['early_stopping_patience']} steps reached"
                                            )
                                            logger.info(
                                                " Training terminated.")
                                            train_iterator.close()
                                        return global_step, tr_loss / global_step
                        else:
                            if (results[args["early_stopping_metric"]] -
                                    best_eval_metric >
                                    args["early_stopping_delta"]):
                                best_eval_metric = results[
                                    args["early_stopping_metric"]]
                                self._save_model(args["best_model_dir"],
                                                 optimizer,
                                                 scheduler,
                                                 model=model,
                                                 results=results)
                                early_stopping_counter = 0
                            else:
                                if args["use_early_stopping"]:
                                    if early_stopping_counter < args[
                                            "early_stopping_patience"]:
                                        early_stopping_counter += 1
                                        if verbose:
                                            logger.info(
                                                f" No improvement in {args['early_stopping_metric']}"
                                            )
                                            logger.info(
                                                f" Current step: {early_stopping_counter}"
                                            )
                                            logger.info(
                                                f" Early stopping patience: {args['early_stopping_patience']}"
                                            )
                                    else:
                                        if verbose:
                                            logger.info(
                                                f" Patience of {args['early_stopping_patience']} steps reached"
                                            )
                                            logger.info(
                                                " Training terminated.")
                                            train_iterator.close()
                                        return global_step, tr_loss / global_step

            epoch_number += 1
            output_dir_current = os.path.join(
                output_dir,
                "checkpoint-{}-epoch-{}".format(global_step, epoch_number))

            if args["save_model_every_epoch"] or args[
                    "evaluate_during_training"]:
                os.makedirs(output_dir_current, exist_ok=True)

            if args["save_model_every_epoch"]:
                self._save_model(output_dir_current,
                                 optimizer,
                                 scheduler,
                                 model=model)

            if args["evaluate_during_training"]:
                results, _, _ = self.eval_model(
                    eval_df,
                    verbose=verbose
                    and args["evaluate_during_training_verbose"],
                    **kwargs)

                self._save_model(output_dir_current,
                                 optimizer,
                                 scheduler,
                                 results=results)

                training_progress_scores["global_step"].append(global_step)
                training_progress_scores["train_loss"].append(current_loss)
                for key in results:
                    training_progress_scores[key].append(results[key])
                report = pd.DataFrame(training_progress_scores)
                report.to_csv(os.path.join(args["output_dir"],
                                           "training_progress_scores.csv"),
                              index=False)

                if args["wandb_project"]:
                    wandb.log(self._get_last_metrics(training_progress_scores))

                if not best_eval_metric:
                    best_eval_metric = results[args["early_stopping_metric"]]
                    self._save_model(args["best_model_dir"],
                                     optimizer,
                                     scheduler,
                                     model=model,
                                     results=results)
                if best_eval_metric and args["early_stopping_metric_minimize"]:
                    if results[args[
                            "early_stopping_metric"]] - best_eval_metric < args[
                                "early_stopping_delta"]:
                        best_eval_metric = results[
                            args["early_stopping_metric"]]
                        self._save_model(args["best_model_dir"],
                                         optimizer,
                                         scheduler,
                                         model=model,
                                         results=results)
                        early_stopping_counter = 0
                    else:
                        if args["use_early_stopping"] and args[
                                "early_stopping_consider_epochs"]:
                            if early_stopping_counter < args[
                                    "early_stopping_patience"]:
                                early_stopping_counter += 1
                                if verbose:
                                    logger.info(
                                        f" No improvement in {args['early_stopping_metric']}"
                                    )
                                    logger.info(
                                        f" Current step: {early_stopping_counter}"
                                    )
                                    logger.info(
                                        f" Early stopping patience: {args['early_stopping_patience']}"
                                    )
                            else:
                                if verbose:
                                    logger.info(
                                        f" Patience of {args['early_stopping_patience']} steps reached"
                                    )
                                    logger.info(" Training terminated.")
                                    train_iterator.close()
                                return global_step, tr_loss / global_step
                else:
                    if results[args[
                            "early_stopping_metric"]] - best_eval_metric > args[
                                "early_stopping_delta"]:
                        best_eval_metric = results[
                            args["early_stopping_metric"]]
                        self._save_model(args["best_model_dir"],
                                         optimizer,
                                         scheduler,
                                         model=model,
                                         results=results)
                        early_stopping_counter = 0
                    else:
                        if args["use_early_stopping"] and args[
                                "early_stopping_consider_epochs"]:
                            if early_stopping_counter < args[
                                    "early_stopping_patience"]:
                                early_stopping_counter += 1
                                if verbose:
                                    logger.info(
                                        f" No improvement in {args['early_stopping_metric']}"
                                    )
                                    logger.info(
                                        f" Current step: {early_stopping_counter}"
                                    )
                                    logger.info(
                                        f" Early stopping patience: {args['early_stopping_patience']}"
                                    )
                            else:
                                if verbose:
                                    logger.info(
                                        f" Patience of {args['early_stopping_patience']} steps reached"
                                    )
                                    logger.info(" Training terminated.")
                                    train_iterator.close()
                                return global_step, tr_loss / global_step

        return global_step, tr_loss / global_step
Code Example #8
def train_source(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate * 0.1}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    acc_init = 0
    max_iter = args.max_epoch * len(dset_loaders["source_tr"])
    interval_iter = max_iter // 10
    iter_num = 0

    netF.train()
    netB.train()
    netC.train()

    # wandb watching
    wandb.watch(netF)
    wandb.watch(netB)
    wandb.watch(netC)

    iter_source = iter(dset_loaders["source_tr"])
    while iter_num < max_iter:
        try:
            inputs_source, labels_source = next(iter_source)
        except StopIteration:  # restart the source loader once it is exhausted
            iter_source = iter(dset_loaders["source_tr"])
            inputs_source, labels_source = next(iter_source)

        if inputs_source.size(0) == 1:
            continue

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        inputs_source, labels_source = inputs_source.cuda(
        ), labels_source.cuda()
        outputs_source = netC(netB(netF(inputs_source)))
        classifier_loss = CrossEntropyLabelSmooth(
            num_classes=args.class_num, epsilon=args.smooth)(outputs_source,
                                                             labels_source)

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            netC.eval()
            if args.dset == 'VISDA18' or args.dset == 'VISDA-C':
                acc_s_te, acc_list = cal_acc(dset_loaders['source_te'], netF,
                                             netB, netC, True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name_src, iter_num, max_iter,
                    acc_s_te) + '\n' + acc_list
                wandb.log({"accuracy": acc_s_te})
            else:
                acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB,
                                      netC, False)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name_src, iter_num, max_iter, acc_s_te)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            if acc_s_te >= acc_init:
                acc_init = acc_s_te
                best_netF = netF.state_dict()
                best_netB = netB.state_dict()
                best_netC = netC.state_dict()

            netF.train()
            netB.train()
            netC.train()

    torch.save(best_netF, osp.join(args.output_dir_src, "source_F.pt"))
    torch.save(best_netB, osp.join(args.output_dir_src, "source_B.pt"))
    torch.save(best_netC, osp.join(args.output_dir_src, "source_C.pt"))

    return netF, netB, netC
Code Example #9
File: __init__.py  Project: repson/client
def log(tf_summary_str, **kwargs):
    # Pop the optional namespace before forwarding the remaining kwargs to wandb.log.
    namespace = kwargs.pop("namespace", None)
    wandb.log(tf_summary_to_dict(tf_summary_str, namespace), **kwargs)
Code Example #10
File: fedavg_trainer.py  Project: zhuzhu603/FedML
    def local_test_on_all_clients(self, model_global, round_idx):

        if self.args.dataset in ["stackoverflow_lr", "stackoverflow_nwp"]:
            # due to the size of the test set, only about 10,000 samples are tested each round
            testlist = random.sample(range(0, self.args.client_num_in_total),
                                     100)
            logging.info(
                "################local_test_round_{}_on_clients : {}".format(
                    round_idx, str(testlist)))
        else:
            logging.info(
                "################local_test_on_all_clients : {}".format(
                    round_idx))
            testlist = list(range(self.args.client_num_in_total))

        train_metrics = {
            'num_samples': [],
            'num_correct': [],
            'precisions': [],
            'recalls': [],
            'losses': []
        }

        test_metrics = {
            'num_samples': [],
            'num_correct': [],
            'precisions': [],
            'recalls': [],
            'losses': []
        }

        client = self.client_list[0]

        for client_idx in testlist:
            """
            Note: for datasets like "fed_CIFAR100" and "fed_shakespeare",
            the training client number is larger than the testing client number
            """
            if self.test_data_local_dict[client_idx] is None:
                continue
            client.update_local_dataset(
                0, self.train_data_local_dict[client_idx],
                self.test_data_local_dict[client_idx],
                self.train_data_local_num_dict[client_idx])
            # train data
            train_local_metrics = client.local_test(model_global, False)
            train_metrics['num_samples'].append(
                copy.deepcopy(train_local_metrics['test_total']))
            train_metrics['num_correct'].append(
                copy.deepcopy(train_local_metrics['test_correct']))
            train_metrics['losses'].append(
                copy.deepcopy(train_local_metrics['test_loss']))

            # test data
            test_local_metrics = client.local_test(model_global, True)
            test_metrics['num_samples'].append(
                copy.deepcopy(test_local_metrics['test_total']))
            test_metrics['num_correct'].append(
                copy.deepcopy(test_local_metrics['test_correct']))
            test_metrics['losses'].append(
                copy.deepcopy(test_local_metrics['test_loss']))

            if self.args.dataset == "stackoverflow_lr":
                train_metrics['precisions'].append(
                    copy.deepcopy(train_local_metrics['test_precision']))
                train_metrics['recalls'].append(
                    copy.deepcopy(train_local_metrics['test_recall']))
                test_metrics['precisions'].append(
                    copy.deepcopy(test_local_metrics['test_precision']))
                test_metrics['recalls'].append(
                    copy.deepcopy(test_local_metrics['test_recall']))
                # due to the size of the test set, only about 10,000 samples are tested each round
                if sum(test_metrics['num_samples']) >= 10000:
                    break
            """
            Note: CI environment is CPU-based computing. 
            The training speed for RNN training is to slow in this setting, so we only test a client to make sure there is no programming error.
            """
            if self.args.ci == 1:
                break

        # test on training dataset
        train_acc = sum(train_metrics['num_correct']) / sum(
            train_metrics['num_samples'])
        train_loss = sum(train_metrics['losses']) / sum(
            train_metrics['num_samples'])
        train_precision = sum(train_metrics['precisions']) / sum(
            train_metrics['num_samples'])
        train_recall = sum(train_metrics['recalls']) / sum(
            train_metrics['num_samples'])

        # test on test dataset
        test_acc = sum(test_metrics['num_correct']) / sum(
            test_metrics['num_samples'])
        test_loss = sum(test_metrics['losses']) / sum(
            test_metrics['num_samples'])
        test_precision = sum(test_metrics['precisions']) / sum(
            test_metrics['num_samples'])
        test_recall = sum(test_metrics['recalls']) / sum(
            test_metrics['num_samples'])

        if self.args.dataset == "stackoverflow_lr":
            stats = {
                'training_acc': train_acc,
                'training_precision': train_precision,
                'training_recall': train_recall,
                'training_loss': train_loss
            }
            wandb.log({"Train/Acc": train_acc, "round": round_idx})
            wandb.log({"Train/Pre": train_precision, "round": round_idx})
            wandb.log({"Train/Rec": train_recall, "round": round_idx})
            wandb.log({"Train/Loss": train_loss, "round": round_idx})
            logging.info(stats)

            stats = {
                'test_acc': test_acc,
                'test_precision': test_precision,
                'test_recall': test_recall,
                'test_loss': test_loss
            }
            wandb.log({"Test/Acc": test_acc, "round": round_idx})
            wandb.log({"Test/Pre": test_precision, "round": round_idx})
            wandb.log({"Test/Rec": test_recall, "round": round_idx})
            wandb.log({"Test/Loss": test_loss, "round": round_idx})
            logging.info(stats)

        else:
            stats = {'training_acc': train_acc, 'training_loss': train_loss}
            wandb.log({"Train/Acc": train_acc, "round": round_idx})
            wandb.log({"Train/Loss": train_loss, "round": round_idx})
            logging.info(stats)

            stats = {'test_acc': test_acc, 'test_loss': test_loss}
            wandb.log({"Test/Acc": test_acc, "round": round_idx})
            wandb.log({"Test/Loss": test_loss, "round": round_idx})
            logging.info(stats)
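
The accuracy and loss reported above are plain sample-weighted averages over the selected clients: each client contributes its correct-prediction count, summed loss, and sample count, and the totals are divided at the end. A minimal, framework-independent sketch of the same aggregation (all numbers are hypothetical):

# Sample-weighted aggregation across clients, mirroring the computation above
# (hypothetical counts; not tied to any federated learning framework).
def aggregate(per_client_numerators, per_client_sample_counts):
    return sum(per_client_numerators) / sum(per_client_sample_counts)

client_correct = [45, 180, 90]   # correct predictions per client
client_samples = [50, 200, 100]  # local sample counts per client
print(aggregate(client_correct, client_samples))  # 0.9
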
Code Example #11
def serve(config, gp_server):
    config.to_defaults()

    # Create Subtensor connection
    subtensor = bittensor.subtensor(config=config)

    # Load/Create our bittensor wallet.
    wallet = bittensor.wallet(config=config).create().register()

    # Load/Sync/Save our metagraph.
    metagraph = bittensor.metagraph(subtensor=subtensor).load().sync().save()

    # Instantiate the model we are going to serve on the network.
    # Creating a threading lock for updates to the model
    mutex = Lock()
    gp_server = gp_server.to(gp_server.device)

    # Create our optimizer.
    optimizer = torch.optim.SGD(
        [{
            "params": gp_server.parameters()
        }],
        lr=config.neuron.learning_rate,
        momentum=config.neuron.momentum,
    )

    timecheck = {}

    # Define our forward function.
    def forward_text(inputs_x):
        r""" Forward function that is called when the axon receives a forward request from other peers.
            Args:
                inputs_x ( :obj:`torch.Tensor`, `required`):
                    torch inputs to be forward processed.

            Returns:
                outputs (:obj:`torch.FloatTensor`):
                    The nucleus's outputs as a torch tensor of shape [batch_size, sequence_len, __network_dim__]
        """
        return gp_server.encode_forward(inputs_x.to(gp_server.device))

    # Define our backward function.
    def backward_text(inputs_x, grads_dy):
        r"""Backward function that is called when the axon receives a backward request from other peers.
            Updates the server parameters with gradients through the chain.

            Args:
                inputs_x ( :obj:`torch.Tensor`, `required`):
                    torch inputs from previous forward call.
                grads_dy ( :obj:`torch.Tensor`, `required`):
                    torch grads of forward output.
                    
        """
        # -- normalized grads --
        grads_dy = grads_dy / (grads_dy.sum() + 0.00001)

        with mutex:
            outputs_y = gp_server.encode_forward(inputs_x.to(gp_server.device))
            with torch.autograd.set_detect_anomaly(True):
                torch.autograd.backward(
                    tensors=[outputs_y],
                    grad_tensors=[grads_dy.to(gp_server.device)],
                    retain_graph=True)
            logger.info('Backwards axon gradient applied')

        gp_server.backward_gradients += inputs_x.size(0)

    def priority(pubkey: str, request_type: bittensor.proto.RequestType,
                 inputs_x) -> float:
        r"""Calculates the priority on requests based on stake and size of input

            Args:
                pubkey ( str, `required`):
                    The public key of the caller.
                inputs_x ( :obj:`torch.Tensor`, `required`):
                    torch inputs to be forward processed.
                request_type ( bittensor.proto.RequestType, `required`):
                    the request type ('FORWARD' or 'BACKWARD').
        """
        uid = metagraph.hotkeys.index(pubkey)
        priority = metagraph.S[uid].item() / sys.getsizeof(inputs_x)

        return priority

    def blacklist(pubkey: str,
                  request_type: bittensor.proto.RequestType) -> bool:
        r"""Axon security blacklisting, used to blacklist messages from low-stake members.
            Args:
                pubkey ( str, `required`):
                    The public key of the caller.
                request_type ( bittensor.proto.RequestType, `required`):
                    the request type ('FORWARD' or 'BACKWARD').
        """

        # Check for stake
        def stake_check() -> bool:
            # If we allow non-registered requests return False = not blacklisted.
            is_registered = pubkey in metagraph.hotkeys
            if not is_registered:
                if config.neuron.blacklist_allow_non_registered:
                    return False
                else:
                    return True

            # Check stake.
            uid = metagraph.hotkeys.index(pubkey)
            if request_type == bittensor.proto.RequestType.FORWARD:
                return (metagraph.S[uid].item() <
                        config.neuron.blacklist.stake.forward)

            elif request_type == bittensor.proto.RequestType.BACKWARD:
                return (metagraph.S[uid].item() <
                        config.neuron.blacklist.stake.backward)

        # Check for time
        def time_check():
            current_time = datetime.now()
            if pubkey in timecheck.keys():
                prev_time = timecheck[pubkey]
                if current_time - prev_time >= timedelta(
                        seconds=config.neuron.blacklist.time):
                    timecheck[pubkey] = current_time
                    return False
                else:
                    timecheck[pubkey] = current_time
                    return True
            else:
                timecheck[pubkey] = current_time
                return False

        # Black list or not
        if stake_check() or time_check():
            return True
        else:
            return False

    # Create our axon server
    axon = bittensor.axon(wallet=wallet,
                          forward_text=forward_text,
                          backward_text=backward_text,
                          blacklist=blacklist,
                          priority=priority)

    # Training Data
    dataset = bittensor.dataset(config=config)

    # load our old model
    if not config.neuron.no_restart:
        gp_server.load(config.neuron.full_path)

    if config.wandb.api_key != 'default':
        # --- Init Wandb.
        bittensor.wandb(config=config,
                        cold_pubkey=wallet.coldkeypub.ss58_address,
                        hot_pubkey=wallet.hotkey.ss58_address,
                        root_dir=config.neuron.full_path)

    nn = subtensor.neuron_for_pubkey(wallet.hotkey.ss58_address)

    # --- last sync block
    last_sync_block = subtensor.get_current_block()
    last_set_block = last_sync_block

    # -- Main Training loop --
    try:
        # -- download files from the mountain
        data = next(dataset)

        # --- creating our chain weights
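        # A single self-weight is placed on this neuron's own uid; the same weight
        # vector is re-sent periodically below to maintain activity on the chain.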
        chain_weights = torch.zeros(metagraph.n)
        uid = nn.uid
        chain_weights[uid] = 1

        # --  serve axon to the network.
        axon.start().serve(subtensor=subtensor)

        while True:
            # --- Run
            current_block = subtensor.get_current_block()
            end_block = current_block + config.neuron.blocks_per_epoch
            iteration = 0

            # --- Training step.
            while end_block >= current_block:
                if current_block != subtensor.get_current_block():
                    loss, _ = gp_server(next(dataset).to(gp_server.device))
                    if iteration > 0:
                        losses += loss
                    else:
                        losses = loss
                    iteration += 1
                    current_block = subtensor.get_current_block()

            # Custom learning rate
            if gp_server.backward_gradients > 0:
                optimizer.param_groups[0]['lr'] = 1 / (
                    gp_server.backward_gradients)
            else:
                optimizer.param_groups[0]['lr'] = 0.1

            # --- Update parameters
            if iteration != 0 or gp_server.backward_gradients != 0:
                with mutex:
                    logger.info('Backpropagation Started')
                    if iteration != 0:
                        losses.backward()
                    clip_grad_norm_(gp_server.parameters(), 1.0)

                    optimizer.step()
                    optimizer.zero_grad()
                    logger.info('Backpropagation Successful: Model updated')

            nn = subtensor.neuron_for_pubkey(wallet.hotkey.ss58_address)

            gp_server.backward_gradients = 0
            # --- logging data
            wandb_data = {
                'block': end_block,
                'loss': losses.cpu().item() / iteration,
                'stake': nn.stake,
                'rank': nn.rank,
                'incentive': nn.incentive,
                'trust': nn.trust,
                'consensus': nn.consensus,
                'dividends': nn.dividends,
                'emission': nn.emission,
            }
            bittensor.__console__.print('[green]Current Status:[/green]',
                                        wandb_data)

            # Add additional wandb data for axon, metagraph etc.
            if config.wandb.api_key != 'default':

                df = pandas.concat([
                    bittensor.utils.indexed_values_to_dataframe(
                        prefix='w_i_{}'.format(nn.uid),
                        index=metagraph.uids,
                        values=metagraph.W[:, uid]),
                    bittensor.utils.indexed_values_to_dataframe(
                        prefix='s_i',
                        index=metagraph.uids,
                        values=metagraph.S),
                    axon.to_dataframe(metagraph=metagraph),
                ],
                                   axis=1)
                df['uid'] = df.index
                stats_data_table = wandb.Table(dataframe=df)
                wandb_info_axon = axon.to_wandb()
                wandb.log({
                    **wandb_data,
                    **wandb_info_axon
                },
                          step=current_block)
                wandb.log({'stats': stats_data_table}, step=current_block)
                wandb.log({
                    'axon_query_times':
                    wandb.plot.scatter(stats_data_table,
                                       "uid",
                                       "axon_query_time",
                                       title="Axon Query time by UID")
                })
                wandb.log({
                    'in_weights':
                    wandb.plot.scatter(stats_data_table,
                                       "uid",
                                       'w_i_{}'.format(nn.uid),
                                       title="Inward weights by UID")
                })
                wandb.log({
                    'stake':
                    wandb.plot.scatter(stats_data_table,
                                       "uid",
                                       's_i',
                                       title="Stake by UID")
                })

            # Save the model
            gp_server.save(config.neuron.full_path)

            if current_block - last_set_block > config.neuron.blocks_per_set_weights:

                # --- Setting weights
                try:
                    last_set_block = current_block
                    # Set self weights to maintain activity.
                    chain_weights = torch.zeros(metagraph.n)
                    chain_weights[uid] = 1
                    did_set = subtensor.set_weights(
                        uids=metagraph.uids,
                        weights=chain_weights,
                        wait_for_inclusion=False,
                        wallet=wallet,
                    )

                    if did_set:
                        logger.success('Successfully set weights on the chain')
                    else:
                        logger.error(
                            'Failed to set weights on chain. (Timeout)')
                except Exception as e:
                    logger.error(
                        'Failure setting weights on chain with error: {}', e)

            if current_block - last_sync_block > config.neuron.metagraph_sync:
                metagraph.sync()
                last_sync_block = current_block

    except KeyboardInterrupt:
        # --- User ended session ----
        axon.stop()
    except Exception as e:
        # --- Unknown error ----
        logger.exception('Unknown exception: {} with traceback {}', e,
                         traceback.format_exc())
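
The blacklist above combines two independent checks: a per-request-type minimum-stake threshold and a per-key rate limit on how often a caller is served. A stripped-down sketch of just the rate-limit part, outside of bittensor (the function name and the 5-second threshold are illustrative):

# Minimal per-caller rate limit, mirroring time_check() above
# (standalone sketch; not part of the bittensor API).
from datetime import datetime, timedelta

last_seen = {}  # pubkey -> time of the most recent request

def rate_limited(pubkey: str, min_interval_s: float = 5.0) -> bool:
    """Return True if the caller should be blacklisted for calling too often."""
    now = datetime.now()
    prev = last_seen.get(pubkey)
    last_seen[pubkey] = now
    return prev is not None and (now - prev) < timedelta(seconds=min_interval_s)

print(rate_limited("caller_pubkey"))  # False on the first request
print(rate_limited("caller_pubkey"))  # True when repeated within the interval
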
Code Example #12
def main(argv):
    del argv  # unused arg
    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('Saving checkpoints at %s', FLAGS.output_dir)
    tf.random.set_seed(FLAGS.seed)

    # Wandb Setup
    if FLAGS.use_wandb:
        pathlib.Path(FLAGS.wandb_dir).mkdir(parents=True, exist_ok=True)
        wandb_args = dict(project=FLAGS.project,
                          entity='uncertainty-baselines',
                          dir=FLAGS.wandb_dir,
                          reinit=True,
                          name=FLAGS.exp_name,
                          group=FLAGS.exp_group)
        wandb_run = wandb.init(**wandb_args)
        wandb.config.update(FLAGS, allow_val_change=True)
        output_dir = str(
            os.path.join(
                FLAGS.output_dir,
                datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')))
    else:
        wandb_run = None
        output_dir = FLAGS.output_dir

    tf.io.gfile.makedirs(output_dir)
    logging.info('Saving checkpoints at %s', output_dir)

    # Log Run Hypers
    hypers_dict = {
        'batch_size': FLAGS.batch_size,
        'base_learning_rate': FLAGS.base_learning_rate,
        'one_minus_momentum': FLAGS.one_minus_momentum,
        'l2': FLAGS.l2,
        'stddev_mean_init': FLAGS.stddev_mean_init,
        'stddev_stddev_init': FLAGS.stddev_stddev_init,
    }
    logging.info('Hypers:')
    logging.info(pprint.pformat(hypers_dict))

    # Initialize distribution strategy on flag-specified accelerator
    strategy = utils.init_distribution_strategy(FLAGS.force_use_cpu,
                                                FLAGS.use_gpu, FLAGS.tpu)
    use_tpu = not (FLAGS.force_use_cpu or FLAGS.use_gpu)

    # Only permit use of L2 regularization with a tied mean prior
    if FLAGS.l2 is not None and FLAGS.l2 > 0 and not FLAGS.tied_mean_prior:
        raise NotImplementedError(
            'For a principled objective, L2 regularization should not be used '
            'when the prior mean is untied from the posterior mean.')

    batch_size = FLAGS.batch_size * FLAGS.num_cores

    # Reweighting loss for class imbalance
    class_reweight_mode = FLAGS.class_reweight_mode
    if class_reweight_mode == 'constant':
        class_weights = utils.get_diabetic_retinopathy_class_balance_weights()
    else:
        class_weights = None

    # Load in datasets.
    datasets, steps = utils.load_dataset(train_batch_size=batch_size,
                                         eval_batch_size=batch_size,
                                         flags=FLAGS,
                                         strategy=strategy)
    available_splits = list(datasets.keys())
    test_splits = [split for split in available_splits if 'test' in split]
    eval_splits = [
        split for split in available_splits
        if 'validation' in split or 'test' in split
    ]

    # Iterate eval datasets
    eval_datasets = {split: iter(datasets[split]) for split in eval_splits}
    dataset_train = datasets['train']
    train_steps_per_epoch = steps['train']
    train_dataset_size = train_steps_per_epoch * batch_size

    if FLAGS.use_bfloat16:
        tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')

    summary_writer = tf.summary.create_file_writer(
        os.path.join(output_dir, 'summaries'))

    if FLAGS.prior_stddev is None:
        logging.info(
            'A fixed prior stddev was not supplied. Computing a prior stddev = '
            'sqrt(2 / fan_in) for each layer. This is recommended over providing '
            'a fixed prior stddev.')

    with strategy.scope():
        logging.info('Building Keras ResNet-50 Radial model.')
        model = None
        if FLAGS.load_from_checkpoint:
            initial_epoch, model = utils.load_keras_checkpoints(
                FLAGS.checkpoint_dir, load_ensemble=False, return_epoch=True)
        else:
            initial_epoch = 0
            model = ub.models.resnet50_radial(
                input_shape=utils.load_input_shape(dataset_train),
                num_classes=1,  # binary classification task
                prior_stddev=FLAGS.prior_stddev,
                dataset_size=train_dataset_size,
                stddev_mean_init=FLAGS.stddev_mean_init,
                stddev_stddev_init=FLAGS.stddev_stddev_init,
                tied_mean_prior=FLAGS.tied_mean_prior)
        utils.log_model_init_info(model=model)

        # Linearly scale learning rate and the decay epochs by vanilla settings.
        base_lr = FLAGS.base_learning_rate
        lr_decay_epochs = [
            (int(start_epoch_str) * FLAGS.train_epochs) // DEFAULT_NUM_EPOCHS
            for start_epoch_str in FLAGS.lr_decay_epochs
        ]
        lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
            train_steps_per_epoch,
            base_lr,
            decay_ratio=FLAGS.lr_decay_ratio,
            decay_epochs=lr_decay_epochs,
            warmup_epochs=FLAGS.lr_warmup_epochs)
        optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                            momentum=1.0 -
                                            FLAGS.one_minus_momentum,
                                            nesterov=True)
        metrics = utils.get_diabetic_retinopathy_base_metrics(
            use_tpu=use_tpu,
            num_bins=FLAGS.num_bins,
            use_validation=FLAGS.use_validation,
            available_splits=available_splits)

        # Radial specific metrics
        metrics.update({
            'train/kl': tf.keras.metrics.Mean(),
            'train/kl_scale': tf.keras.metrics.Mean()
        })

        # TODO(nband): debug or remove
        # checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        # latest_checkpoint = tf.train.latest_checkpoint(output_dir)
        # if latest_checkpoint:
        #   # checkpoint.restore must be within a strategy.scope()
        #   # so that optimizer slot variables are mirrored.
        #   checkpoint.restore(latest_checkpoint)
        #   logging.info('Loaded checkpoint %s', latest_checkpoint)
        #   initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

    # Define OOD metrics outside the accelerator scope for CPU eval.
    # This will cause an error on TPU.
    if not use_tpu:
        metrics.update(
            utils.get_diabetic_retinopathy_cpu_metrics(
                available_splits=available_splits,
                use_validation=FLAGS.use_validation))

    for test_split in test_splits:
        metrics.update(
            {f'{test_split}/ms_per_example': tf.keras.metrics.Mean()})

    # Initialize loss function based on class reweighting setting
    loss_fn = utils.get_diabetic_retinopathy_loss_fn(
        class_reweight_mode=class_reweight_mode, class_weights=class_weights)

    # * Prepare for Evaluation *

    # Get the wrapper function which will produce uncertainty estimates for
    # our choice of method and Y/N ensembling.
    uncertainty_estimator_fn = utils.get_uncertainty_estimator(
        'radial', use_ensemble=False, use_tf=True)

    # Wrap our estimator to predict probabilities (apply sigmoid on logits)
    eval_estimator = utils.wrap_retinopathy_estimator(
        model, use_mixed_precision=FLAGS.use_bfloat16, numpy_outputs=False)

    estimator_args = {'num_samples': FLAGS.num_mc_samples_eval}

    @tf.function
    def train_step(iterator):
        """Training step function."""
        def step_fn(inputs):
            """Per-replica step function."""
            images = inputs['features']
            labels = inputs['labels']

            # For minibatch class reweighting, initialize per-batch loss function
            if class_reweight_mode == 'minibatch':
                batch_loss_fn = utils.get_minibatch_reweighted_loss_fn(
                    labels=labels)
            else:
                batch_loss_fn = loss_fn

            with tf.GradientTape() as tape:
                if FLAGS.num_mc_samples_train > 1:
                    logits_list = []
                    for _ in range(FLAGS.num_mc_samples_train):
                        logits = model(images, training=True)
                        logits = tf.squeeze(logits, axis=-1)
                        if FLAGS.use_bfloat16:
                            logits = tf.cast(logits, tf.float32)

                        logits_list.append(logits)

                    # Logits dimension is (num_samples, batch_size).
                    logits_list = tf.stack(logits_list, axis=0)

                    probs_list = tf.nn.sigmoid(logits_list)
                    probs = tf.reduce_mean(probs_list, axis=0)
                    negative_log_likelihood = tf.reduce_mean(
                        batch_loss_fn(y_true=tf.expand_dims(labels, axis=-1),
                                      y_pred=probs,
                                      from_logits=False))
                else:
                    # Single train step
                    logits = model(images, training=True)
                    if FLAGS.use_bfloat16:
                        logits = tf.cast(logits, tf.float32)

                    negative_log_likelihood = tf.reduce_mean(
                        batch_loss_fn(y_true=tf.expand_dims(labels, axis=-1),
                                      y_pred=logits,
                                      from_logits=True))
                    probs = tf.squeeze(tf.nn.sigmoid(logits))

                filtered_variables = []
                for var in model.trainable_variables:
                    # Apply l2 on the BN parameters and bias terms. This
                    # excludes only fast weight approximate posterior/prior parameters,
                    # but be careful about their naming scheme.
                    if 'bn' in var.name or 'bias' in var.name:
                        filtered_variables.append(tf.reshape(var, (-1, )))

                l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
                    tf.concat(filtered_variables, axis=0))
                kl = sum(model.losses)
                kl_scale = tf.cast(optimizer.iterations + 1, kl.dtype)
                kl_scale /= train_steps_per_epoch * FLAGS.kl_annealing_epochs
                kl_scale = tf.minimum(1., kl_scale)
                kl_loss = kl_scale * kl

                loss = negative_log_likelihood + l2_loss + kl_loss

                # Scale the loss given the TPUStrategy will reduce sum all gradients.
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/kl'].update_state(kl)
            metrics['train/kl_scale'].update_state(kl_scale)
            metrics['train/accuracy'].update_state(labels, probs)
            metrics['train/auprc'].update_state(labels, probs)
            metrics['train/auroc'].update_state(labels, probs)

            if not use_tpu:
                metrics['train/ece'].add_batch(probs, label=labels)

        for _ in tf.range(tf.cast(train_steps_per_epoch, tf.int32)):
            strategy.run(step_fn, args=(next(iterator), ))

    start_time = time.time()

    train_iterator = iter(dataset_train)
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch + 1)
        train_step(train_iterator)

        current_step = (epoch + 1) * train_steps_per_epoch
        max_steps = train_steps_per_epoch * FLAGS.train_epochs
        time_elapsed = time.time() - start_time
        steps_per_sec = float(current_step) / time_elapsed
        eta_seconds = (max_steps - current_step) / steps_per_sec
        message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                   'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                       current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                       steps_per_sec, eta_seconds / 60, time_elapsed / 60))
        logging.info(message)

        # Run evaluation on all evaluation datasets, and compute metrics
        per_pred_results, total_results = utils.evaluate_model_and_compute_metrics(
            strategy,
            eval_datasets,
            steps,
            metrics,
            eval_estimator,
            uncertainty_estimator_fn,
            batch_size,
            available_splits,
            estimator_args=estimator_args,
            call_dataset_iter=False,
            is_deterministic=False,
            num_bins=FLAGS.num_bins,
            use_tpu=use_tpu,
            return_per_pred_results=True)

        # Optionally log to wandb
        if FLAGS.use_wandb:
            wandb.log(total_results, step=epoch)

        with summary_writer.as_default():
            for name, result in total_results.items():
                if result is not None:
                    tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
            metric.reset_states()

        if (FLAGS.checkpoint_interval > 0
                and (epoch + 1) % FLAGS.checkpoint_interval == 0):
            # checkpoint_name = checkpoint.save(
            #     os.path.join(output_dir, 'checkpoint'))
            # logging.info('Saved checkpoint to %s', checkpoint_name)

            # TODO(nband): debug checkpointing
            # Also save Keras model, due to checkpoint.save issue
            keras_model_name = os.path.join(output_dir,
                                            f'keras_model_{epoch + 1}')
            model.save(keras_model_name)
            logging.info('Saved keras model to %s', keras_model_name)

            # Save per-prediction metrics
            utils.save_per_prediction_results(output_dir,
                                              epoch + 1,
                                              per_pred_results,
                                              verbose=False)

    # final_checkpoint_name = checkpoint.save(
    #     os.path.join(output_dir, 'checkpoint'),)
    # logging.info('Saved last checkpoint to %s', final_checkpoint_name)

    keras_model_name = os.path.join(output_dir,
                                    f'keras_model_{FLAGS.train_epochs}')
    model.save(keras_model_name)
    logging.info('Saved keras model to %s', keras_model_name)

    # Save per-prediction metrics
    utils.save_per_prediction_results(output_dir,
                                      FLAGS.train_epochs,
                                      per_pred_results,
                                      verbose=False)

    with summary_writer.as_default():
        hp.hparams({
            'base_learning_rate': FLAGS.base_learning_rate,
            'one_minus_momentum': FLAGS.one_minus_momentum,
            'l2': FLAGS.l2,
            'stddev_mean_init': FLAGS.stddev_mean_init,
            'stddev_stddev_init': FLAGS.stddev_stddev_init,
        })

    if wandb_run is not None:
        wandb_run.finish()
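
In the training step above, the KL term is annealed linearly: its scale grows with the optimizer step count until it reaches 1 after kl_annealing_epochs worth of steps, and stays at 1 afterwards. A small standalone sketch of that schedule (the sizes are hypothetical):

# Linear KL annealing as computed in step_fn above (hypothetical sizes).
def kl_scale(step, steps_per_epoch=100, kl_annealing_epochs=10):
    return min(1.0, (step + 1) / (steps_per_epoch * kl_annealing_epochs))

print(kl_scale(0))     # 0.001 -> the KL term barely contributes at the start
print(kl_scale(499))   # 0.5   -> halfway through the annealing window
print(kl_scale(2000))  # 1.0   -> full KL weight after the annealing window
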
Code Example #13
def train(hyp, opt, device, tb_writer=None, wandb=None):
    logger.info(
        colorstr('hyperparameters: ') + ', '.join(f'{k}={v}'
                                                  for k, v in hyp.items()))
    save_dir, epochs, batch_size, total_batch_size, weights, rank = \
        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank

    # Directories
    wdir = save_dir / 'weights'
    wdir.mkdir(parents=True, exist_ok=True)  # make dir
    last = wdir / 'last.pt'
    best = wdir / 'best.pt'
    results_file = save_dir / 'results.txt'

    # Save run settings
    with open(save_dir / 'hyp.yaml', 'w') as f:
        yaml.dump(hyp, f, sort_keys=False)
    with open(save_dir / 'opt.yaml', 'w') as f:
        yaml.dump(vars(opt), f, sort_keys=False)

    # Configure
    plots = not opt.evolve  # create plots
    cuda = device.type != 'cpu'
    init_seeds(2 + rank)
    with open(opt.data) as f:
        data_dict = yaml.load(f, Loader=yaml.SafeLoader)  # data dict
    with torch_distributed_zero_first(rank):
        check_dataset(data_dict)  # check
    train_path = data_dict['train']
    test_path = data_dict['val']
    nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes
    names = ['item'] if opt.single_cls and len(
        data_dict['names']) != 1 else data_dict['names']  # class names
    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (
        len(names), nc, opt.data)  # check

    # Model
    pretrained = weights.endswith('.pt')
    if pretrained:
        with torch_distributed_zero_first(rank):
            attempt_download(weights)  # download if not found locally
        ckpt = torch.load(weights, map_location=device)  # load checkpoint
        if hyp.get('anchors'):
            ckpt['model'].yaml['anchors'] = round(
                hyp['anchors'])  # force autoanchor
        model = Model(opt.cfg or ckpt['model'].yaml, ch=3,
                      nc=nc).to(device)  # create
        exclude = ['anchor'] if opt.cfg or hyp.get('anchors') else [
        ]  # exclude keys
        state_dict = ckpt['model'].float().state_dict()  # to FP32
        state_dict = intersect_dicts(state_dict,
                                     model.state_dict(),
                                     exclude=exclude)  # intersect
        model.load_state_dict(state_dict, strict=False)  # load
        logger.info(
            'Transferred %g/%g items from %s' %
            (len(state_dict), len(model.state_dict()), weights))  # report
    else:
        model = Model(opt.cfg, ch=3, nc=nc).to(device)  # create

    # Freeze
    freeze = []  # parameter names to freeze (full or partial)
    for k, v in model.named_parameters():
        v.requires_grad = True  # train all layers
        if any(x in k for x in freeze):
            print('freezing %s' % k)
            v.requires_grad = False

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / total_batch_size),
                     1)  # accumulate loss before optimizing
    hyp['weight_decay'] *= total_batch_size * accumulate / nbs  # scale weight_decay
    logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")

    pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
    for k, v in model.named_modules():
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
            pg2.append(v.bias)  # biases
        if isinstance(v, nn.BatchNorm2d):
            pg0.append(v.weight)  # no decay
        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
            pg1.append(v.weight)  # apply decay

    if opt.adam:
        optimizer = optim.Adam(pg0,
                               lr=hyp['lr0'],
                               betas=(hyp['momentum'],
                                      0.999))  # adjust beta1 to momentum
    else:
        optimizer = optim.SGD(pg0,
                              lr=hyp['lr0'],
                              momentum=hyp['momentum'],
                              nesterov=True)

    optimizer.add_param_group({
        'params': pg1,
        'weight_decay': hyp['weight_decay']
    })  # add pg1 with weight_decay
    optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
    logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' %
                (len(pg2), len(pg1), len(pg0)))
    del pg0, pg1, pg2

    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
    lf = one_cycle(1, hyp['lrf'], epochs)  # cosine 1->hyp['lrf']
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    # plot_lr_scheduler(optimizer, scheduler, epochs)

    # Logging
    if rank in [-1, 0] and wandb and wandb.run is None:
        opt.hyp = hyp  # add hyperparameters
        wandb_run = wandb.init(
            config=opt,
            resume="allow",
            project='YOLOv3-kaist-v028'
            if opt.project == 'runs/train' else Path(opt.project).stem,
            name=save_dir.stem,
            id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
    loggers = {'wandb': wandb}  # loggers dict

    # Resume
    start_epoch, best_fitness = 0, 0.0
    if pretrained:
        # Optimizer
        if ckpt['optimizer'] is not None:
            optimizer.load_state_dict(ckpt['optimizer'])
            best_fitness = ckpt['best_fitness']

        # Results
        if ckpt.get('training_results') is not None:
            with open(results_file, 'w') as file:
                file.write(ckpt['training_results'])  # write results.txt

        # Epochs
        start_epoch = ckpt['epoch'] + 1
        if opt.resume:
            assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (
                weights, epochs)
        if epochs < start_epoch:
            logger.info(
                '%s has been trained for %g epochs. Fine-tuning for %g additional epochs.'
                % (weights, ckpt['epoch'], epochs))
            epochs += ckpt['epoch']  # finetune additional epochs

        del ckpt, state_dict

    # Image sizes
    gs = int(model.stride.max())  # grid size (max stride)
    nl = model.model[
        -1].nl  # number of detection layers (used for scaling hyp['obj'])
    imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size
                         ]  # verify imgsz are gs-multiples

    # DP mode
    if cuda and rank == -1 and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # SyncBatchNorm
    if opt.sync_bn and cuda and rank != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        logger.info('Using SyncBatchNorm()')

    # EMA
    ema = ModelEMA(model) if rank in [-1, 0] else None

    # DDP mode
    if cuda and rank != -1:
        model = DDP(model,
                    device_ids=[opt.local_rank],
                    output_device=opt.local_rank)

    # Trainloader
    dataloader, dataset = create_dataloader(train_path,
                                            imgsz,
                                            batch_size,
                                            gs,
                                            opt,
                                            hyp=hyp,
                                            augment=True,
                                            cache=opt.cache_images,
                                            rect=opt.rect,
                                            rank=rank,
                                            world_size=opt.world_size,
                                            workers=opt.workers,
                                            image_weights=opt.image_weights,
                                            quad=opt.quad,
                                            prefix=colorstr('train: '))
    mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
    nb = len(dataloader)  # number of batches
    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (
        mlc, nc, opt.data, nc - 1)

    # Process 0
    if rank in [-1, 0]:
        ema.updates = start_epoch * nb // accumulate  # set EMA updates
        testloader = create_dataloader(
            test_path,
            imgsz_test,
            total_batch_size,
            gs,
            opt,  # testloader
            hyp=hyp,
            cache=opt.cache_images and not opt.notest,
            rect=True,
            rank=-1,
            world_size=opt.world_size,
            workers=opt.workers,
            pad=0.5,
            prefix=colorstr('val: '))[0]

        if not opt.resume:
            labels = np.concatenate(dataset.labels, 0)
            c = torch.tensor(labels[:, 0])  # classes
            # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
            # model._initialize_biases(cf.to(device))
            if plots:
                plot_labels(labels, save_dir, loggers)
                if tb_writer:
                    tb_writer.add_histogram('classes', c, 0)

            # Anchors
            if not opt.noautoanchor:
                check_anchors(dataset,
                              model=model,
                              thr=hyp['anchor_t'],
                              imgsz=imgsz)

    # Model parameters
    hyp['box'] *= 3. / nl  # scale to layers
    hyp['cls'] *= nc / 80. * 3. / nl  # scale to classes and layers
    hyp['obj'] *= (imgsz / 640)**2 * 3. / nl  # scale to image size and layers
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.gr = 1.0  # iou loss ratio (obj_loss = 1.0 or iou)
    model.class_weights = labels_to_class_weights(
        dataset.labels, nc).to(device) * nc  # attach class weights
    model.names = names

    # Start training
    t0 = time.time()
    nw = max(round(hyp['warmup_epochs'] * nb),
             1000)  # number of warmup iterations, max(3 epochs, 1k iterations)
    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0
               )  # P, R, mAP@0.5, mAP@0.5:0.95, val_loss(box, obj, cls)
    scheduler.last_epoch = start_epoch - 1  # do not move
    scaler = amp.GradScaler(enabled=cuda)
    logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
                f'Using {dataloader.num_workers} dataloader workers\n'
                f'Logging results to {save_dir}\n'
                f'Starting training for {epochs} epochs...')
    for epoch in range(
            start_epoch, epochs
    ):  # epoch ------------------------------------------------------------------
        model.train()

        # Update image weights (optional)
        if opt.image_weights:
            # Generate indices
            if rank in [-1, 0]:
                cw = model.class_weights.cpu().numpy() * (
                    1 - maps)**2 / nc  # class weights
                iw = labels_to_image_weights(dataset.labels,
                                             nc=nc,
                                             class_weights=cw)  # image weights
                dataset.indices = random.choices(
                    range(dataset.n), weights=iw,
                    k=dataset.n)  # rand weighted idx
            # Broadcast if DDP
            if rank != -1:
                indices = (torch.tensor(dataset.indices)
                           if rank == 0 else torch.zeros(dataset.n)).int()
                dist.broadcast(indices, 0)
                if rank != 0:
                    dataset.indices = indices.cpu().numpy()

        # Update mosaic border
        # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
        # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

        mloss = torch.zeros(4, device=device)  # mean losses
        if rank != -1:
            dataloader.sampler.set_epoch(epoch)
        pbar = enumerate(dataloader)
        logger.info(
            ('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls',
                                   'total', 'targets', 'img_size'))
        if rank in [-1, 0]:
            pbar = tqdm(pbar, total=nb)  # progress bar
        optimizer.zero_grad()
        for i, (
                imgs, targets, paths, _
        ) in pbar:  # batch -------------------------------------------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device, non_blocking=True).float(
            ) / 255.0  # uint8 to float32, 0-255 to 0.0-1.0

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                # model.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
                accumulate = max(
                    1,
                    np.interp(ni, xi, [1, nbs / total_batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(ni, xi, [
                        hyp['warmup_bias_lr'] if j == 2 else 0.0,
                        x['initial_lr'] * lf(epoch)
                    ])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(
                            ni, xi, [hyp['warmup_momentum'], hyp['momentum']])

            # Multi-scale
            if opt.multi_scale:
                sz = random.randrange(imgsz * 0.5,
                                      imgsz * 1.5 + gs) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]
                          ]  # new shape (stretched to gs-multiple)
                    imgs = F.interpolate(imgs,
                                         size=ns,
                                         mode='bilinear',
                                         align_corners=False)

            # Forward
            with amp.autocast(enabled=cuda):
                pred = model(imgs)  # forward
                loss, loss_items = compute_loss(
                    pred, targets.to(device),
                    model)  # loss scaled by batch_size
                if rank != -1:
                    loss *= opt.world_size  # gradient averaged between devices in DDP mode
                if opt.quad:
                    loss *= 4.

            # Backward
            scaler.scale(loss).backward()

            # Optimize
            if ni % accumulate == 0:
                scaler.step(optimizer)  # optimizer.step
                scaler.update()
                optimizer.zero_grad()
                if ema:
                    ema.update(model)

            # Print
            if rank in [-1, 0]:
                mloss = (mloss * i + loss_items) / (i + 1
                                                    )  # update mean losses
                mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9
                                 if torch.cuda.is_available() else 0)  # (GB)
                s = ('%10s' * 2 +
                     '%10.4g' * 6) % ('%g/%g' % (epoch, epochs - 1), mem,
                                      *mloss, targets.shape[0], imgs.shape[-1])
                pbar.set_description(s)

                # Plot
                if plots and ni < 3:
                    f = save_dir / f'train_batch{ni}.jpg'  # filename
                    Thread(target=plot_images,
                           args=(imgs, targets, paths, f),
                           daemon=True).start()
                    # if tb_writer:
                    #     tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
                    #     tb_writer.add_graph(model, imgs)  # add model to tensorboard
                elif plots and ni == 3 and wandb:
                    wandb.log({
                        "Mosaics": [
                            wandb.Image(str(x), caption=x.name)
                            for x in save_dir.glob('train*.jpg')
                        ]
                    })

            # end batch ------------------------------------------------------------------------------------------------
        # end epoch ----------------------------------------------------------------------------------------------------

        # Scheduler
        lr = [x['lr'] for x in optimizer.param_groups]  # for tensorboard
        scheduler.step()

        # DDP process 0 or single-GPU
        if rank in [-1, 0]:
            # mAP
            if ema:
                ema.update_attr(model,
                                include=[
                                    'yaml', 'nc', 'hyp', 'gr', 'names',
                                    'stride', 'class_weights'
                                ])
            final_epoch = epoch + 1 == epochs
            if not opt.notest or final_epoch:  # Calculate mAP
                results, maps, times = test.test(
                    opt.data,
                    batch_size=total_batch_size,
                    imgsz=imgsz_test,
                    model=ema.ema,
                    single_cls=opt.single_cls,
                    dataloader=testloader,
                    save_dir=save_dir,
                    plots=plots and final_epoch,
                    log_imgs=opt.log_imgs if wandb else 0)

            # Write
            with open(results_file, 'a') as f:
                f.write(
                    s + '%10.4g' * 7 % results +
                    '\n')  # P, R, mAP@0.5, mAP@0.5:0.95, val_loss(box, obj, cls)
            if len(opt.name) and opt.bucket:
                os.system('gsutil cp %s gs://%s/results/results%s.txt' %
                          (results_file, opt.bucket, opt.name))

            # Log
            tags = [
                'train/box_loss',
                'train/obj_loss',
                'train/cls_loss',  # train loss
                'metrics/precision',
                'metrics/recall',
                'metrics/mAP_0.5',
                'metrics/mAP_0.5:0.95',
                'val/box_loss',
                'val/obj_loss',
                'val/cls_loss',  # val loss
                'x/lr0',
                'x/lr1',
                'x/lr2'
            ]  # params
            for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
                if tb_writer:
                    tb_writer.add_scalar(tag, x, epoch)  # tensorboard
                if wandb:
                    wandb.log({tag: x})  # W&B

            # Update best mAP
            fi = fitness(np.array(results).reshape(
                1, -1))  # weighted combination of [P, R, mAP@0.5, mAP@0.5:0.95]
            if fi > best_fitness:
                best_fitness = fi

            # Save model
            save = (not opt.nosave) or (final_epoch and not opt.evolve)
            if save:
                with open(results_file, 'r') as f:  # create checkpoint
                    ckpt = {
                        'epoch':
                        epoch,
                        'best_fitness':
                        best_fitness,
                        'training_results':
                        f.read(),
                        'model':
                        ema.ema,
                        'optimizer':
                        None if final_epoch else optimizer.state_dict(),
                        'wandb_id':
                        wandb_run.id if wandb else None
                    }

                # Save last, best and delete
                torch.save(ckpt, last)
                if best_fitness == fi:
                    torch.save(ckpt, best)
                del ckpt
        # end epoch ----------------------------------------------------------------------------------------------------
    # end training

    if rank in [-1, 0]:
        # Strip optimizers
        final = best if best.exists() else last  # final model
        for f in [last, best]:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
        if opt.bucket:
            os.system(f'gsutil cp {final} gs://{opt.bucket}/weights')  # upload

        # Plots
        if plots:
            plot_results(save_dir=save_dir)  # save as results.png
            if wandb:
                files = [
                    'results.png', 'precision_recall_curve.png',
                    'confusion_matrix.png'
                ]
                wandb.log({
                    "Results": [
                        wandb.Image(str(save_dir / f), caption=f)
                        for f in files if (save_dir / f).exists()
                    ]
                })
                if opt.log_artifacts:
                    wandb.log_artifact(artifact_or_path=str(final),
                                       type='model',
                                       name=save_dir.stem)

        # Test best.pt
        logger.info('%g epochs completed in %.3f hours.\n' %
                    (epoch - start_epoch + 1, (time.time() - t0) / 3600))
        if opt.data.endswith('kaist.yaml') and nc == 10:  # if KAIST
            for conf, iou, save_json in ([0.25, 0.45,
                                          False], [0.001, 0.65,
                                                   True]):  # speed, mAP tests
                results, _, _ = test.test(opt.data,
                                          batch_size=total_batch_size,
                                          imgsz=imgsz_test,
                                          conf_thres=conf,
                                          iou_thres=iou,
                                          model=attempt_load(final,
                                                             device).half(),
                                          single_cls=opt.single_cls,
                                          dataloader=testloader,
                                          save_dir=save_dir,
                                          save_json=save_json,
                                          plots=False)

    else:
        dist.destroy_process_group()

    wandb.run.finish() if wandb and wandb.run else None
    torch.cuda.empty_cache()
    return results
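
One detail worth noting in the loop above is the gradient accumulation: with a nominal batch size of nbs = 64, the optimizer only steps every accumulate = max(round(nbs / total_batch_size), 1) batches, so small per-GPU batches still approximate an effective batch of 64. A bare-bones PyTorch sketch of that pattern (the linear model, random data, and learning rate are placeholders):

# Gradient accumulation toward a nominal batch size (placeholder model and data).
import torch

nbs, total_batch_size = 64, 16
accumulate = max(round(nbs / total_batch_size), 1)  # -> step every 4 batches here

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for i in range(20):  # stands in for iterating a dataloader
    x, y = torch.randn(total_batch_size, 10), torch.randn(total_batch_size, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()                  # gradients accumulate across iterations
    if (i + 1) % accumulate == 0:    # optimizer step only every few batches
        optimizer.step()
        optimizer.zero_grad()
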
Code Example #14
File: run.py Project: we1l1n/tensor2struct-public
def eval_and_report(args, exp_config, model_config_args, logdir, infer_mod):
    model_config_file = exp_config["model_config"]

    summary = collections.defaultdict(float)
    for step in exp_config["eval_steps"]:
        infer_output_path = "{}/{}-step{}.infer".format(
            exp_config["eval_output"], exp_config["eval_name"], step)

        # eval_only mode skips inference and scores existing infer outputs
        if args.mode != "eval_only":
            infer_config = InferConfig(
                model_config_file,
                model_config_args,
                logdir,
                exp_config["eval_section"],
                exp_config["eval_beam_size"],
                infer_output_path,
                step,
                debug=exp_config["eval_debug"],
                method=exp_config["eval_method"],
            )

            try:
                infer_mod.main(infer_config)
            except infer.CheckpointNotFoundError as e:
                print(f"Infer error {str(e)}")
                continue

        eval_output_path = "{}/{}-step{}.eval".format(
            exp_config["eval_output"], exp_config["eval_name"], step)
        eval_config = EvalConfig(
            model_config_file,
            model_config_args,
            logdir,
            exp_config["eval_section"],
            infer_output_path,
            eval_output_path,
        )

        # etype
        if "eval_type" in exp_config:
            eval_config.etype = exp_config["eval_type"]
        else:
            assert eval_config.etype == "match"

        try:
            metrics = eval.main(eval_config)
        except infer.CheckpointNotFoundError as e:
            print(f"Eval error {str(e)}")
            continue

        # update some exp configs
        wandb.config.update({
            "eval_method": exp_config["eval_method"],
            "eval_section": exp_config["eval_section"],
            "eval_beam_size": exp_config["eval_beam_size"],
        })
        if "args" in exp_config:
            wandb.config.update({"exp_args": exp_config["args"]})

        # commit with step
        eval_section = exp_config["eval_section"]
        if "all" in metrics["total_scores"]:  # spider
            exact_match = metrics["total_scores"]["all"]["exact"]
            exec_match = metrics["total_scores"]["all"]["exec"]
            print(
                "Step: ",
                step,
                "\tmatch score,",
                exact_match,
                "\texe score:",
                exec_match,
            )
            wandb.log(
                {
                    f"{eval_section}_exact_match": exact_match,
                    f"{eval_section}_exe_acc": exec_match,
                },
                step=step,
            )

            if exact_match > summary[f"{eval_section}-best-exact_match"]:
                summary[f"{eval_section}-best-exact_match"] = exact_match
                summary[f"{eval_section}-best_exact_match_step"] = step
            if exec_match > summary[f"{eval_section}-best-exec_match"]:
                summary[f"{eval_section}-best-exec_match"] = exec_match
                summary[f"{eval_section}-best_exec_match_step"] = step
        else:  # wikisql, etc
            lf_accuracy = metrics["total_scores"]["lf_accuracy"]
            exe_accuracy = metrics["total_scores"]["exe_accuracy"]
            wandb.log(
                {f"{eval_section}-lf-accuracy": lf_accuracy},
                step=step,
            )
            wandb.log(
                {f"{eval_section}-exe-accuracy": exe_accuracy},
                step=step,
            )
            print(step, metrics["total_scores"])

            if lf_accuracy > summary[f"{eval_section}-best-lf-accuracy"]:
                summary[f"{eval_section}-best-lf-accuracy"] = lf_accuracy
                summary[f"{eval_section}-best_lf_accuracy_step"] = step
            if exe_accuracy > summary[f"{eval_section}-best-exe-accuracy"]:
                summary[f"{eval_section}-best-exe-accuracy"] = exe_accuracy
                summary[f"{eval_section}-best_exe_accuracy_step"] = step

    # sync summary to wandb
    print("Summary:", str(summary))
    for item in summary:
        wandb.run.summary[item] = summary[item]
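
The loop above keeps a running best value per metric in a collections.defaultdict(float) and only writes the final bests to wandb.run.summary once every evaluation step has been scored. The same bookkeeping in isolation (the metric values are made up):

# Track best-so-far metrics across evaluation steps (made-up values),
# mirroring the defaultdict bookkeeping above.
import collections

summary = collections.defaultdict(float)
scores = {1000: 0.41, 2000: 0.55, 3000: 0.52}  # step -> exact match accuracy

for step, exact_match in scores.items():
    if exact_match > summary["val-best-exact_match"]:
        summary["val-best-exact_match"] = exact_match
        summary["val-best_exact_match_step"] = step

print(dict(summary))  # {'val-best-exact_match': 0.55, 'val-best_exact_match_step': 2000}
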
Code Example #15
def run(config):

    run_config = config['run_config']
    model_config = config['model_config']
    param_config = config['param_config']
    data_config = config['data_config']
    log_config = config['log_config']

    if log_config['wandb']:
        wandb.init(project="pmnist", name=log_config['wandb_name'])
        wandb.config.update(config)

    # Reproducibility
    seed = run_config['seed']
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Loss
    criterion = nn.CrossEntropyLoss()

    # Model
    net = getattr(model, model_config['arch']).Model(model_config)
    net.to(run_config['device'])
    net.apply(initialize_weights)

    # Data
    Dataset = getattr(datasets, data_config['dataset'])

    # Training
    memories = []
    validloaders = []
    s = 0

    for task_id, task in enumerate(run_config['tasks'], 0):

        validset = Dataset(dset='test',
                           valid=data_config['valid'],
                           transform=data_config['test_transform'],
                           task=task)
        validloaders.append(
            DataLoader(validset,
                       batch_size=param_config['batch_size'],
                       shuffle=False,
                       pin_memory=True,
                       num_workers=data_config['num_workers']))
        trainset = Dataset(dset='train',
                           valid=data_config['valid'],
                           transform=data_config['train_transform'],
                           task=task)

        bufferloader = MultiLoader([trainset] + memories,
                                   batch_size=param_config['batch_size'])

        optimizer = torch.optim.SGD(
            net.parameters(),
            lr=param_config['model_lr'],
        )
        train = Train(optimizer, criterion, bufferloader, config)

        d_net = copy.deepcopy(net)
        d_trainloader = DataLoader(
            trainset,
            batch_size=param_config['distill_batch_size'],
            shuffle=True,
            pin_memory=True,
            num_workers=data_config['num_workers'])
        d_validloader = DataLoader(
            validset,
            batch_size=param_config['distill_batch_size'],
            shuffle=False,
            pin_memory=True,
            num_workers=data_config['num_workers'])

        if param_config['steps'] == 'epoch':
            steps = len(bufferloader) * param_config['no_steps']
        elif param_config['steps'] == 'minibatch':
            steps = param_config['no_steps']
        else:
            raise ValueError

        for step in range(steps):

            buffer_loss, buffer_accuracy = train(net)

            if int(steps * 0.05) <= 0 or step % int(steps * 0.05) == int(
                    steps * 0.05) - 1 or step == 0:

                valid_m = {'Test accuracy avg': 0}
                for i, vl in enumerate(validloaders):
                    test_loss, test_accuracy = test(net, criterion, vl,
                                                    run_config)
                    valid_m = {
                        **valid_m,
                        **{
                            f'Test loss {i}': test_loss,
                            f'Test accuracy {i}': test_accuracy,
                        }
                    }
                    valid_m['Test accuracy avg'] += (test_accuracy /
                                                     len(validloaders))

                train_m = {
                    'Buffer loss': buffer_loss,
                    'Buffer accuracy': buffer_accuracy,
                    'Step': s
                }
                s += 1
                if log_config['print']:
                    print({**valid_m, **train_m})
                if log_config['wandb']:
                    wandb.log({**valid_m, **train_m})

        if task_id == len(run_config['tasks']) - 1:
            break

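        # Build the replay memory for the task just finished: distil a small
        # synthetic buffer per class (buffer_size copies), or keep the full
        # training set when buffer_size == -1.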
        if param_config['buffer_size'] != 0:
            buffer = None
            for b in range(param_config['buffer_size']):
                buff = None

                for c in range(model_config['n_classes']):
                    ds = list(
                        filter(
                            lambda x: x[1] == c,
                            Dataset(dset='train',
                                    valid=data_config['valid'],
                                    transform=data_config['train_transform'],
                                    task=task)))
                    buff = Buffer(ds,
                                  1) if buff is None else buff + Buffer(ds, 1)

                net.drop = nn.Dropout(0.0)
                buff, _ = distill(d_net, buff, config, criterion,
                                  d_trainloader, d_validloader, task_id)
                net.drop = nn.Dropout(model_config['dropout'])

                buffer = buff if buffer is None else buffer + buff

            if param_config['buffer_size'] == -1:
                memories.append(
                    Dataset(dset='train',
                            valid=data_config['valid'],
                            transform=data_config['train_transform'],
                            task=task))
            else:
                memories.append(buffer)
Code Example #16
0
    def step(self):
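        # One agent step: collect transitions from the actor, store them in the
        # replay buffer and, once exploration_steps have passed, update the
        # network with a TD-style loss on the psi (successor-feature-like) head.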
        config = self.config

        # Store transitions in the buffer
        transitions = self.actor.step()
        experiences = []
        for state, action, reward, next_state, next_action, done, info in transitions:
            #             self.record_online_return(info)
            self.total_steps += 1
            reward = config.reward_normalizer(reward)
            experiences.append(
                [state, action, reward, next_state, next_action, done])
        self.replay.feed_batch(experiences)

        # Start updating network parameters after exploration_steps
        if self.total_steps > self.config.exploration_steps:

            # Getting samples from buffer
            experiences = self.replay.sample()
            states, actions, rewards, next_states, next_actions, terminals = experiences
            states = self.config.state_normalizer(states)
            next_states = self.config.state_normalizer(next_states)

            # Estimate targets
            with torch.no_grad():
                _, psi_next, _ = self.network(next_states)

            if self.config.double_q:
                # Assumption: the network's third output is a Q-value head;
                # use it to pick greedy next actions (double Q) and gather the
                # matching successor features, mirroring the branch below.
                with torch.no_grad():
                    _, _, q_next = self.network(next_states)
                best_actions = torch.argmax(q_next, dim=-1)
                psi_next = psi_next[self.batch_indices, best_actions, :]
            else:
                next_actions = tensor(next_actions).long()
                psi_next = psi_next[
                    self.batch_indices,
                    next_actions, :]  # TODO: double check dims here

            terminals = tensor(terminals)
            psi_next = self.config.discount * psi_next * (
                1 - terminals.unsqueeze(1).repeat(1, psi_next.shape[1]))
            phi, psi, _ = self.network(states)
            psi_next.add_(phi)  # TODO: double check this

            # Computing estimates
            actions = tensor(actions).long()
            psi = psi[self.batch_indices, actions, :]

            #             loss_psi = (psi_next - psi).pow(2).mul(0.5).mean(0)
            loss_psi = (psi_next - psi).pow(2).mul(0.5).mean()

            loss = loss_psi

            total_loss = loss.mean()
            self.loss_vec.append(total_loss.item())
            self.loss_psi_vec.append(total_loss.item())

            if self.is_wb:
                wandb.log({
                    "steps_loss": self.total_steps,
                    "loss": loss.item(),
                    # only the psi loss is computed in this snippet
                    "loss_psi": loss_psi.item(),
                })

            self.optimizer.zero_grad()
            #             loss.backward(torch.ones(loss.shape))
            loss.backward()

            nn.utils.clip_grad_norm_(self.network.parameters(),
                                     self.config.gradient_clip)

            with config.lock:
                self.optimizer.step()
Code Example #17
0
def distill(model, buffer, config, criterion, train_loader, valid_loader, id):
    model = copy.deepcopy(model)

    run_config = config['run_config']
    param_config = config['param_config']
    log_config = config['log_config']

    model.train()
    eval_trainloader = copy.deepcopy(train_loader)

    buff_imgs, buff_trgs = next(
        iter(DataLoader(buffer, batch_size=len(buffer))))
    buff_imgs, buff_trgs = buff_imgs.to(run_config['device']), buff_trgs.to(
        run_config['device'])

    buff_imgs.requires_grad = True

    init_valid = DataLoader(ModelInitDataset(model, 10),
                            batch_size=1,
                            collate_fn=lambda x: x)
    init_loader = DataLoader(ModelInitDataset(model, -1),
                             batch_size=1,
                             collate_fn=lambda x: x)
    init_iter = iter(init_loader)

    buff_opt = torch.optim.SGD(
        [buff_imgs],
        lr=param_config['meta_lr'],
    )

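    # Per-inner-step learning rates are stored in inverse-softplus form
    # (log(exp(lr) - 1)) so that the effective rate log(1 + exp(.)) used in the
    # inner loop stays positive while being optimised by SGD.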
    lr_list = []
    lr_opts = []
    for _ in range(param_config['inner_steps']):
        lr = np.log(np.exp([param_config['model_lr']]) - 1)
        lr = torch.tensor(lr, requires_grad=True, device=run_config['device'])
        lr_list.append(lr)
        lr_opts.append(torch.optim.SGD(
            [lr],
            param_config['lr_lr'],
        ))

    for i in range(param_config['outer_steps']):
        for step, (ds_imgs, ds_trgs) in enumerate(train_loader):
            try:
                init_batch = next(init_iter)
            except StopIteration:
                init_iter = iter(init_loader)
                init_batch = next(init_iter)

            ds_imgs = ds_imgs.to(run_config['device'])
            ds_trgs = ds_trgs.to(run_config['device'])

            acc_loss = None
            epoch_loss = [None for _ in range(param_config['inner_steps'])]

            for r, sigma in enumerate(init_batch):
                model.load_state_dict(sigma)
                model_opt = torch.optim.SGD(
                    model.parameters(),
                    lr=1,
                )
                with higher.innerloop_ctx(model,
                                          model_opt) as (fmodel, diffopt):
                    for j in range(param_config['inner_steps']):

                        buff_out = fmodel(buff_imgs)
                        buff_loss = criterion(buff_out, buff_trgs)
                        buff_loss = buff_loss * torch.log(
                            1 + torch.exp(lr_list[j]))
                        diffopt.step(buff_loss)

                        ds_out = fmodel(ds_imgs)
                        ds_loss = criterion(ds_out, ds_trgs)

                        epoch_loss[j] = epoch_loss[j] + ds_loss if epoch_loss[
                            j] is not None else ds_loss
                        acc_loss = acc_loss + ds_loss if acc_loss is not None else ds_loss


                        log_every = int(round(len(train_loader) * param_config['outer_steps'] * 0.05))
                        global_step = step + i * len(train_loader)
                        if (global_step % log_every == log_every - 1 or global_step == 0) \
                                and j == param_config['inner_steps'] - 1 and r == 0:

                            lrs = [
                                np.log(np.exp(lr.item()) + 1) for lr in lr_list
                            ]
                            lrs_log = {
                                f'Learning rate {i} - {id}': lr
                                for (i, lr) in enumerate(lrs)
                            }
                            train_loss, train_accuracy = test_distill(
                                init_valid, lrs, [buff_imgs, buff_trgs], model,
                                criterion, eval_trainloader, run_config)
                            test_loss, test_accuracy = test_distill(
                                init_valid, lrs, [buff_imgs, buff_trgs], model,
                                criterion, valid_loader, run_config)
                            metrics = {
                                f'Distill train loss {id}': train_loss,
                                f'Distill train accuracy {id}': train_accuracy,
                                f'Distill test loss {id}': test_loss,
                                f'Distill test accuracy {id}': test_accuracy,
                                f'Distill step {id}':
                                step + i * len(train_loader)
                            }

                            if log_config['wandb']:
                                wandb.log({**metrics, **lrs_log})

                            if log_config['print']:
                                print(metrics)

            # Update the lrs
            for j in range(param_config['inner_steps']):
                lr_opts[j].zero_grad()
                grad, = autograd.grad(epoch_loss[j],
                                      lr_list[j],
                                      retain_graph=True)
                lr_list[j].grad = grad
                lr_opts[j].step()

            buff_opt.zero_grad()
            acc_loss.backward()
            buff_opt.step()

    aux = []
    buff_imgs, buff_trgs = buff_imgs.detach().cpu(), buff_trgs.detach().cpu()
    for i in range(buff_imgs.size(0)):
        aux.append([buff_imgs[i], buff_trgs[i]])
    lr_list = [np.log(1 + np.exp(lr.item())) for lr in lr_list]

    return Buffer(
        aux,
        len(aux),
    ), lr_list
Code Example #18
0
def train(
    model: object,
    epoch,
    gradient_clip,
    learning_rate,
    train_loader: torch.utils.data.dataloader.DataLoader,
    valid_loader: torch.utils.data.dataloader.DataLoader,
    early_stopping_threshold: int = 10,
    early_stopping: bool = True,
) -> object:
    """
    :param model:  Torch model
    :param train_loader:  Training DataLoader
    :param valid_loader: Validation DataLoader
    :param learning_rate: Learning rate for the optimizer
    :param epoch: Number of passes through the entire training set
    :param gradient_clip: Max gradient norm used for clipping
    :param early_stopping_threshold:  number of epochs without improvement before stopping
    :param early_stopping: Bool to indicate early stopping

    :return: a model object
    """

    if early_stopping:
        stopping = EarlyStopping(threshold=early_stopping_threshold,
                                 verbose=True)

    wandb.watch(model)

    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    if torch.cuda.is_available():
        model.cuda()

    counter = 0

    for e in range(epoch):
        train_loss_list, val_loss_list, train_acc_list, val_acc_list = [], [], [], []

        for train_inputs, train_labels in train_loader:
            counter += 1
            model.init_hidden()

            # train_inputs = train_inputs.view(256, 1, -1, 216).float()

            if torch.cuda.is_available():
                train_inputs, train_labels = train_inputs.cuda(
                ), train_labels.cuda()

            # _, train_pred = torch.max(torch.sigmoid(train_output), 1)

            model.zero_grad()
            train_output = model(train_inputs)

            train_acc, train_f1, train_pr, train_rc = _metric_summary(
                pred=torch.max(train_output, dim=1).indices.data.cpu().numpy(),
                label=train_labels.cpu().numpy(),
            )

            train_loss = criterion(train_output, train_labels)
            train_loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
            optimizer.step()

            train_loss_list.append(train_loss.item())
            train_acc_list.append(train_acc)

        log_scalar(name="Accuracy/train", value=train_acc, step=e)
        log_scalar(name="Precision/train", value=train_rc, step=e)
        log_scalar(name="F1/train", value=train_f1, step=e)
        log_scalar(name="Recall/train", value=train_rc, step=e)
        log_scalar(name="Loss/train", value=train_loss.item(), step=e)

        model.init_hidden()
        model.eval()

        for val_inputs, val_labels in valid_loader:

            if torch.cuda.is_available():
                val_inputs, val_labels = val_inputs.cuda(), val_labels.cuda()

            val_output = model(val_inputs)
            _, val_pred = torch.max(val_output, 1)

            val_loss = criterion(val_output, val_labels)
            val_loss_list.append(val_loss.item())

            val_acc, val_f1, val_pr, val_rc = _metric_summary(
                pred=torch.max(val_output, dim=1).indices.data.cpu().numpy(),
                label=val_labels.cpu().numpy(),
            )

            val_acc_list.append(val_acc)

        wandb.log({"Accuracy/val": val_acc}, step=e)
        wandb.log({"Precision/val": val_pr}, step=e)
        wandb.log({"Recall/val": val_rc}, step=e)
        wandb.log({"Loss/val": val_loss.item()}, step=e)
        log_scalar(name="F1/val", value=val_f1, step=e)

        model.train()
        _logger.info("Epoch: {}/{}..."
                     "Training Loss: {:.3f}..."
                     "Validation Loss: {:.3f}..."
                     "Train Accuracy: {:.3f}..."
                     "Test Accuracy: {:.3f}".format(
                         e + 1,
                         epoch,
                         np.mean(train_loss_list),
                         np.mean(val_loss_list),
                         np.mean(train_acc_list),
                         np.mean(val_acc_list),
                     ))

        if early_stopping:
            stopping(val_loss=val_loss, model=model)
            if stopping.early_stop:
                _logger.info("Stopping Model Early")
                break

    wandb.sklearn.plot_confusion_matrix(
        val_labels.cpu().numpy(),
        val_pred.cpu().numpy(),
        valid_loader.dataset.classes,
    )

    _logger.info("Done Training, uploaded model to {}".format(wandb.run.dir))
    return model
Code Example #19
0
    "end": [1, 1.5, 1]
}, {
    "start": [1, 1, 1],
    "end": [1, 1, 1.5]
}, {
    "start": [1, 1, 1],
    "end": [1.2, 1.5, 1.5]
}]

vectors_2 = [{
    "start": [2, 2, 2],
    "end": [1, 1.5, 1],
    "color": [255, 255, 0]
}, {
    "start": [2, 2, 2],
    "end": [1, 1, 1.5],
    "color": [255, 255, 0],
}, {
    "start": [2, 2, 2],
    "end": [1.2, 1.5, 1.5],
    "color": [255, 255, 0]
}]

vectors_all = vectors + vectors_2

wandb.log({
    "separate_vectors": [make_scene([v]) for v in vectors],
    "color_vectors": make_scene(vectors_2),
    "all_vectors": make_scene(vectors_all)
})
Code Example #20
0
    val_loss_per_epoch = history.history['val_loss']
    best_epoch = val_loss_per_epoch.index(min(val_loss_per_epoch)) + 1
    print('Best epoch: %d' % (best_epoch,))

    # Fit best model
    final_model = LSTMIMO.fit(global_inputs_X,global_inputs_T,epochs=best_epoch,batch_size=BATCHSIZE,verbose=0)

    metrics = pd.DataFrame(columns=['mae','mape', 'rmse', 'B'], index=range(28))
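    # Evaluate the refitted model on every test set and collect MAE/MAPE/RMSE
    # per dataset before logging the averages to wandb.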
    for i,df in enumerate(datasets):
        concat_input = tf.concat([dX_test[i]['X'],dX_test[i]['X2']], axis=2)
        FD_predictions = LSTMIMO.predict(concat_input)
        FD_eval_df = create_evaluation_df(FD_predictions, dX_test[i], HORIZON, dX_scaler[i])
        mae = validation(FD_eval_df['prediction'], FD_eval_df['actual'], 'MAE')
        mape = validation(FD_eval_df['prediction'], FD_eval_df['actual'], 'MAPE')
        rmse = validation(FD_eval_df['prediction'], FD_eval_df['actual'], 'RMSE')
        #print('rmse {}'.format(rmse))
        metrics.loc[i] = pd.Series({'mae':mae, 'mape':mape, 'rmse':rmse, 'B': names[i]})
    wandb.log({"mape": metrics.mape.mean()})
    wandb.log({"rmse": metrics.rmse.mean()})
    wandb.log({"mae": metrics.mae.mean()})
    if HORIZON == 72:
        metrics.to_csv('./results/'+dset+'/global/3days/LSTM_'+wandb.run.name+'.csv')
        model_path = '.models/'+dset+'_models/global_'+wandb.run.name
        save_model(LSTMIMO, model_path)
    if HORIZON == 24:
        metrics.to_csv('./results/'+dset+'/global/dayahead/LSTM_'+wandb.run.name+'.csv')
        model_path = '.models/'+dset+'_models/global_'+wandb.run.name
        save_model(LSTMIMO, model_path)
    run.finish()
Code Example #21
0
def main():  # pylint:disable=too-many-locals, too-many-statements
    """Runs everything"""
    X_train_, y_train_, X_test, y_test = get_initial_split(df_full_factorial_feat, y)

    for train_size in TRAIN_SIZES:
        for i in range(REPEAT):
            X_train, _, y_train, _ = train_test_split(X_train_, y_train_, train_size=train_size)

            # Train coregionalized model
            wandb.init(project='dispersant_screener', tags=['coregionalized', 'matern32'], reinit=True)
            m = build_coregionalized_model(X_train, y_train)
            m.optimize_restarts(20)
            y0, var0 = predict_coregionalized(m, X_test, 0)
            y1, var1 = predict_coregionalized(m, X_test, 1)
            metrics_0 = get_metrics(y0, y_test[:, 0])
            metrics_0 = add_postfix_to_keys(metrics_0, 0)

            metrics_1 = get_metrics(y1, y_test[:, 1])
            metrics_1 = add_postfix_to_keys(metrics_1, 1)

            variance_0 = get_variance_descriptors(var0)
            variance_1 = get_variance_descriptors(var1)
            variance_0 = add_postfix_to_keys(variance_0, 0)
            variance_1 = add_postfix_to_keys(variance_1, 1)

            overall_metrics = metrics_0
            overall_metrics.update(metrics_1)
            overall_metrics.update(variance_0)
            overall_metrics.update(variance_1)
            overall_metrics['train_size'] = len(X_train)
            overall_metrics['coregionalized'] = True

            METRICS.append(overall_metrics)

            plot_parity([(y0, y_test[:, 0], var0), (y1, y_test[:, 1], var1)],
                        'coregionalized_{}_{}.pdf'.format(len(X_train), i))
            wandb.log(overall_metrics)
            wandb.join()

            # Train "simple models"
            wandb.init(project='dispersant_screener', tags=['matern32'], reinit=True)
            m0 = build_model(X_train, y_train, 0)
            m0.optimize_restarts(20)
            m1 = build_model(X_train, y_train, 1)
            m1.optimize_restarts(20)

            y0, var0 = predict(m0, X_test)
            y1, var1 = predict(m1, X_test)
            metrics_0 = get_metrics(y0, y_test[:, 0])
            metrics_0 = add_postfix_to_keys(metrics_0, 0)

            metrics_1 = get_metrics(y1, y_test[:, 1])
            metrics_1 = add_postfix_to_keys(metrics_1, 1)

            variance_0 = get_variance_descriptors(var0)
            variance_1 = get_variance_descriptors(var1)
            variance_0 = add_postfix_to_keys(variance_0, 0)
            variance_1 = add_postfix_to_keys(variance_1, 1)

            overall_metrics = metrics_0
            overall_metrics.update(metrics_1)
            overall_metrics.update(variance_0)
            overall_metrics.update(variance_1)
            overall_metrics['train_size'] = len(X_train)
            overall_metrics['coregionalized'] = False

            METRICS.append(overall_metrics)

            plot_parity([(y0, y_test[:, 0], var0), (y1, y_test[:, 1], var1)],
                        'simple_{}_{}.pdf'.format(len(X_train), i))

            wandb.log(overall_metrics)
            wandb.join()

    df = pd.DataFrame(METRICS)
    df.to_csv('metrics.csv')
Code Example #22
0
    # model.load_state_dict(saved_checkpoint)
    model = torch.load(args.checkpoint_path).cuda()
    # device = torch.device('cuda:0')
    # model.to(device)
    # model = nn.DataParallel(model)

    criterion = nn.CrossEntropyLoss()
    val_accuracy = 0

    model.eval()
    with torch.no_grad():
        loss_total = 0
        correct = 0
        total_sample = 0

        for batch in tqdm(test_dataloader):
            encoded_text, _, target, _ = batch

            outputs = model(encoded_text.cuda())
            loss = criterion(outputs, target.cuda())
            _, predicted = torch.max(outputs.data, 1)
            correct += (np.array(predicted.cpu()) == np.array(
                target.cpu())).sum()
            loss_total += loss.item() * encoded_text.shape[0]
            total_sample += encoded_text.shape[0]

        acc_total = correct / total_sample
        loss_total = loss_total / total_sample
        wandb.log({'Test loss': loss_total})
        wandb.log({'Test accuracy': acc_total})
        print('Test accuracy ', acc_total)
Code Example #23
0
File: train.py  Project: phamdinhkhanh/PFLD-pytorch
def main(args):
    # Step 1: parse args config
    logging.basicConfig(
        format=
        '[%(asctime)s] [p%(process)s] [%(pathname)s:%(lineno)d] [%(levelname)s] %(message)s',
        level=logging.INFO,
        handlers=[
            logging.FileHandler(args.log_file, mode='w'),
            logging.StreamHandler()
        ])
    print_args(args)

    # Step 2: model, criterion, optimizer, scheduler
    if wandb.config.pfld_backbone == "GhostNet":
        plfd_backbone = CustomizedGhostNet(width=wandb.config.ghostnet_width, dropout=0.2)
        logger.info(f"Using GHOSTNET with width={wandb.config.ghostnet_width} as backbone of PFLD backbone")

        # If using pretrained weight from ghostnet model trained on image net
        if (wandb.config.ghostnet_with_pretrained_weight_image_net == True):
            logger.info(f"Using pretrained weights of ghostnet model trained on image net data ")
            plfd_backbone = load_pretrained_weight_imagenet_for_ghostnet_backbone(
                plfd_backbone, "./checkpoint_imagenet/state_dict_93.98.pth")
            


    else:
        plfd_backbone = PFLDInference().to(device)  # MobileNet2 default
        logger.info("Using MobileNet2 as backbone of PFLD backbone")

    auxiliarynet = AuxiliaryNet().to(device)

    # Watch model by wandb
    wandb.watch(plfd_backbone)
    wandb.watch(auxiliarynet)

    criterion = PFLDLoss()
    optimizer = torch.optim.Adam(
        [{
            'params': plfd_backbone.parameters()
        }, {
            'params': auxiliarynet.parameters()
        }],
        lr=args.base_lr,
        weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=args.lr_patience, verbose=True)

    # step 3: data
    # augmentation
    transform = transforms.Compose([transforms.ToTensor()])
    wlfwdataset = WLFWDatasets(args.dataroot, transform)
    dataloader = DataLoader(
        wlfwdataset,
        batch_size=args.train_batchsize,
        shuffle=True,
        num_workers=args.workers,
        drop_last=False)

    wlfw_val_dataset = WLFWDatasets(args.val_dataroot, transform)
    wlfw_val_dataloader = DataLoader(
        wlfw_val_dataset,
        batch_size=args.val_batchsize,
        shuffle=False,
        num_workers=args.workers)

    # step 4: run
    writer = SummaryWriter(args.tensorboard)
    for epoch in range(args.start_epoch, args.end_epoch + 1):
        weighted_train_loss, train_loss = train(dataloader, plfd_backbone, auxiliarynet,
                                      criterion, optimizer, epoch)
        filename = os.path.join(
            str(args.snapshot), "checkpoint_epoch_" + str(epoch) + '.pth.tar')
        save_checkpoint({
            'epoch': epoch,
            'plfd_backbone': plfd_backbone.state_dict(),
            'auxiliarynet': auxiliarynet.state_dict()
        }, filename)

        val_loss = validate(wlfw_val_dataloader, plfd_backbone, auxiliarynet,
                            criterion)
        
        wandb.log({"metric/val_loss": val_loss})

        scheduler.step(val_loss)
        writer.add_scalar('data/weighted_loss', weighted_train_loss, epoch)
        writer.add_scalars('data/loss', {'val loss': val_loss, 'train loss': train_loss}, epoch)
    writer.close()
Code Example #24
0
def learn(*, policy, env, eval_env, nsteps, total_timesteps, ent_coef, lr,
             vf_coef=0.5,  max_grad_norm=0.5, gamma=0.99, lam=0.95,
            log_interval=10, nminibatches=4, noptepochs=4, cliprange=0.2,
            save_interval=0, load_path=None):
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    mpi_size = comm.Get_size()

    #tf.compat.v1.disable_v2_behavior()
    sess = tf.compat.v1.get_default_session()

    if isinstance(lr, float): lr = constfn(lr)
    else: assert callable(lr)
    if isinstance(cliprange, float): cliprange = constfn(cliprange)
    else: assert callable(cliprange)
    total_timesteps = int(total_timesteps)
    
    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    nbatch = nenvs * nsteps
    
    nbatch_train = nbatch // nminibatches

    model = Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nbatch_act=nenvs, nbatch_train=nbatch_train,
                    nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
                    max_grad_norm=max_grad_norm)

    utils.load_all_params(sess)

    runner = Runner(env=env, eval_env=eval_env, model=model, nsteps=nsteps, gamma=gamma, lam=lam)

    epinfobuf10 = deque(maxlen=10)
    epinfobuf100 = deque(maxlen=100)
    eval_epinfobuf100 = deque(maxlen=100)
    tfirststart = time.time()
    active_ep_buf = epinfobuf100
    eval_active_ep_buf = eval_epinfobuf100

    nupdates = total_timesteps//nbatch
    mean_rewards = []
    datapoints = []

    run_t_total = 0
    train_t_total = 0

    can_save = False
    checkpoints = [32, 64]
    saved_key_checkpoints = [False] * len(checkpoints)

    if Config.SYNC_FROM_ROOT and rank != 0:
        can_save = False

    def save_model(base_name=None):
        base_dict = {'datapoints': datapoints}
        utils.save_params_in_scopes(sess, ['model'], Config.get_save_file(base_name=base_name), base_dict)

    # For logging purposes, allow restoring of update
    start_update = 0
    if Config.RESTORE_STEP is not None:
        start_update = Config.RESTORE_STEP // nbatch

    z_iter = 0
    curr_z = np.random.randint(0, high=Config.POLICY_NHEADS)
    tb_writer = TB_Writer(sess)
    import os
    os.environ["WANDB_API_KEY"] = "02e3820b69de1b1fcc645edcfc3dd5c5079839a1"
    group_name = "%s__%s__%d__%d__%f__%d" %(Config.ENVIRONMENT,Config.RUN_ID,Config.CLUSTER_T,Config.N_KNN, Config.TEMP, Config.N_SKILLS)
    name = "%s__%s__%d__%d__%f__%d__%d" %(Config.ENVIRONMENT,Config.RUN_ID,Config.CLUSTER_T,Config.N_KNN,  Config.TEMP, Config.N_SKILLS, np.random.randint(100000000))
    wandb.init(project='ising_generalization' if Config.ENVIRONMENT == 'ising' else 'procgen_generalization' , entity='ssl_rl', config=Config.args_dict, group=group_name, name=name, mode="disabled" if Config.DISABLE_WANDB else "online")
    for update in range(start_update+1, nupdates+1):
        assert nbatch % nminibatches == 0
        nbatch_train = nbatch // nminibatches
        tstart = time.time()
        frac = 1.0 - (update - 1.0) / nupdates
        lrnow = lr(frac)
        cliprangenow = cliprange(frac)

        # if Config.CUSTOM_REP_LOSS:
        #     params = tf.compat.v1.trainable_variables()
        #     source_params = [p for p in params if p.name in model.train_model.RL_enc_param_names]
        #     for i in range(1,Config.POLICY_NHEADS):
        #         target_i_params = [p for p in params if p.name in model.train_model.target_enc_param_names[i]]
        #         soft_update(source_params,target_i_params,tau=0.95)

        mpi_print('collecting rollouts...')
        run_tstart = time.time()
        # if z_iter < 4: # 8 epochs / skill
        #     z_iter += 1
        # else:
        #     # sample new skill for current episodes
        #     curr_z = np.random.randint(0, high=Config.POLICY_NHEADS)
        #     model.head_idx_current_batch = curr_z
        #     z_iter = 0

        packed = runner.run(update_frac=update/nupdates)
    
        obs, returns, masks, actions, values, neglogpacs, infos, rewards, epinfos, eval_epinfos = packed
        values_i = returns_i = states_nce = anchors_nce = labels_nce = actions_nce = neglogps_nce = rewards_nce = infos_nce = None
    
        # reshape our augmented state vectors to match first dim of observation array
        # (mb_size*num_envs, 64*64*RGB)
        # (mb_size*num_envs, num_actions)
        avg_value = np.mean(values)
        epinfobuf10.extend(epinfos)
        epinfobuf100.extend(epinfos)
        eval_epinfobuf100.extend(eval_epinfos)

        run_elapsed = time.time() - run_tstart
        run_t_total += run_elapsed
        mpi_print('rollouts complete')

        mblossvals = []

        mpi_print('updating parameters...')
        train_tstart = time.time()

        mean_cust_loss = 0
        inds = np.arange(nbatch)
        inds_nce = np.arange(nbatch//runner.nce_update_freq)
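        # Each optimisation epoch shuffles the batch indices and, for every
        # minibatch, runs three separate updates: policy, encoder and latent.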
        for _ in range(noptepochs):
            np.random.shuffle(inds)
            np.random.shuffle(inds_nce)
            for start in range(0, nbatch, nbatch_train):
                sess.run([model.train_model.train_dropout_assign_ops])
                end = start + nbatch_train
                mbinds = inds[start:end]

                
                slices = (arr[mbinds] for arr in (obs, returns, masks, actions, infos, values, neglogpacs, rewards))
                
                mblossvals.append(model.train(lrnow, cliprangenow, *slices, train_target='policy'))
                slices = (arr[mbinds] for arr in (obs, returns, masks, actions, infos, values, neglogpacs, rewards))
                model.train(lrnow, cliprangenow, *slices, train_target='encoder')
                slices = (arr[mbinds] for arr in (obs, returns, masks, actions, infos, values, neglogpacs, rewards))
                model.train(lrnow, cliprangenow, *slices, train_target='latent')
        # update the dropout mask
        sess.run([model.train_model.train_dropout_assign_ops])
        sess.run([model.train_model.run_dropout_assign_ops])

        train_elapsed = time.time() - train_tstart
        train_t_total += train_elapsed
        mpi_print('update complete')

        lossvals = np.mean(mblossvals, axis=0)
        tnow = time.time()
        fps = int(nbatch / (tnow - tstart))

        if update % log_interval == 0 or update == 1:
            step = update*nbatch
            eval_rew_mean = utils.process_ep_buf(eval_active_ep_buf, tb_writer=tb_writer, suffix='_eval', step=step)
            rew_mean_10 = utils.process_ep_buf(active_ep_buf, tb_writer=tb_writer, suffix='', step=step)
            
            ep_len_mean = np.nanmean([epinfo['l'] for epinfo in active_ep_buf])
            
            mpi_print('\n----', update)

            mean_rewards.append(rew_mean_10)
            datapoints.append([step, rew_mean_10])
            tb_writer.log_scalar(ep_len_mean, 'ep_len_mean', step=step)
            tb_writer.log_scalar(fps, 'fps', step=step)
            tb_writer.log_scalar(avg_value, 'avg_value', step=step)
            tb_writer.log_scalar(mean_cust_loss, 'custom_loss', step=step)


            mpi_print('time_elapsed', tnow - tfirststart, run_t_total, train_t_total)
            mpi_print('timesteps', update*nsteps, total_timesteps)

            # eval_rew_mean = episode_rollouts(eval_env,model,step,tb_writer)

            mpi_print('eplenmean', ep_len_mean)
            mpi_print('eprew', rew_mean_10)
            mpi_print('eprew_eval', eval_rew_mean)
            mpi_print('fps', fps)
            mpi_print('total_timesteps', update*nbatch)
            mpi_print([epinfo['r'] for epinfo in epinfobuf10])

            rep_loss = 0
            if len(mblossvals):
                for (lossval, lossname) in zip(lossvals, model.loss_names):
                    mpi_print(lossname, lossval)
                    tb_writer.log_scalar(lossval, lossname, step=step)
            mpi_print('----\n')

            wandb.log({"%s/eprew"%(Config.ENVIRONMENT):rew_mean_10,
                        "%s/eprew_eval"%(Config.ENVIRONMENT):eval_rew_mean,
                        "%s/custom_step"%(Config.ENVIRONMENT):step})
        if can_save:
            if save_interval and (update % save_interval == 0):
                save_model()

            for j, checkpoint in enumerate(checkpoints):
                if (not saved_key_checkpoints[j]) and (step >= (checkpoint * 1e6)):
                    saved_key_checkpoints[j] = True
                    save_model(str(checkpoint) + 'M')

    save_model()

    env.close()
    return mean_rewards
Code Example #25
0
def train(config, train_dataloader, dev_dataloader):
    # bert_config = BertConfig.from_pretrained(config['model_path'])
    model = BertForSequenceClassification.from_pretrained(config['model_path'])

    wandb.watch(model)

    optimizer = AdamW(model.parameters(), lr=config['learning_rate'])
    # lr_scheduler = ReduceLROnPlateau(optimizer=optimizer, mode='max', factor=0.5,
    #                                  patience=2, verbose=True)
    model.to(config['device'])
    # fgm = FGM(model)
    pgd = PGD(model)
    K = 3
    epoch_iterator = trange(config['num_epochs'])
    global_steps = 0
    train_loss = 0.
    logging_loss = 0.
    best_roc_auc = 0.
    best_model_path = ''

    if config['n_gpus'] > 1:
        model = nn.DataParallel(model)

    for epoch in epoch_iterator:

        train_iterator = tqdm(train_dataloader,
                              desc='Training',
                              total=len(train_dataloader))
        model.train()
        for batch in train_iterator:
            batch_cuda = {
                item: value.to(config['device'])
                for item, value in list(batch.items())
            }
            loss = model(**batch_cuda)[0]

            if config['n_gpus'] > 1:
                loss = loss.mean()

            model.zero_grad()
            loss.backward()

            pgd.backup_grad()
            for t in range(K):
                pgd.attack(is_first_attack=(t == 0))
                if t != K - 1:
                    model.zero_grad()
                else:
                    pgd.restore_grad()
                loss_adv = model(**batch_cuda)[0]
                if config['n_gpus'] > 1:
                    loss_adv = loss_adv.mean()
                loss_adv.backward()
            pgd.restore()

            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            train_loss += loss.item()
            global_steps += 1

            train_iterator.set_postfix_str(
                f'running training loss: {loss.item():.4f}')
            wandb.log({'running training loss': loss.item()},
                      step=global_steps)

            if global_steps % config['logging_step'] == 0:
                print_train_loss = (train_loss -
                                    logging_loss) / config['logging_step']
                logging_loss = train_loss

                val_loss, roc_auc_score, pr_auc_score, acc = evaluation(
                    config, model, dev_dataloader)

                print_log = f'>>> training loss: {print_train_loss:.4f}, valid loss: {val_loss:.4f}, '
                # lr_scheduler.step(metrics=roc_auc_score, epoch=global_steps // config['logging_step'])

                if roc_auc_score > best_roc_auc:
                    model_save_path = os.path.join(
                        config['output_path'],
                        f'checkpoint-{global_steps}-{roc_auc_score:.3f}')
                    model_to_save = model.module if hasattr(
                        model, 'module') else model
                    model_to_save.save_pretrained(model_save_path)
                    best_roc_auc = roc_auc_score
                    best_model_path = model_save_path

                print_log += f'valid roc-auc: {roc_auc_score:.3f}, valid pr-auc: {pr_auc_score}, valid acc: {acc:.3f}'
                print(print_log)
                log_wandb_metrics = {
                    'training loss': print_train_loss,
                    'valid loss': val_loss,
                    'valid roc-auc': roc_auc_score,
                    'valid pr-auc': pr_auc_score,
                    'valid acc': acc
                }
                wandb.log(log_wandb_metrics, step=global_steps)
                model.train()

    return model, best_model_path
Code Example #26
0
def _train(num_epochs,
           loaders,
           model,
           optimizer,
           criterion,
           save_path,
           min_loss=np.Inf):
    """returns trained model"""
    # initialize tracker for minimum validation loss
    val_loss_min = min_loss
    num_train_iters = ceil(
        len(loaders['train'].dataset) / loaders['train'].batch_size)
    num_val_iters = ceil(
        len(loaders['val'].dataset) / loaders['val'].batch_size)

    for epoch in range(1, num_epochs + 1):
        # initialize variables to monitor training and validation loss
        train_loss, val_loss = 0.0, 0.0
        train_metrics, val_metrics = np.zeros(2), np.zeros(2)

        # training the model
        model.train()
        for data, labels in loaders['train']:
            # move data, labels to run_device
            data = data.to(run_device)
            labels = [label.to(run_device) for label in labels]

            # forward pass, backward pass and update weights
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # calculate training loss and metrics
            train_loss += loss.item()
            train_metrics += compute_metrics(outputs, labels)

        # evaluating the model
        model.eval()
        for data, labels in loaders['val']:
            # move data, labels to run_device
            data = data.to(run_device)
            labels = [label.to(run_device) for label in labels]

            # forward pass without grad to calculate the validation loss
            with torch.no_grad():
                outputs = model(data)
                loss = criterion(outputs, labels)

            # calculate validation loss
            val_loss += loss.item()
            val_metrics += compute_metrics(outputs, labels)

        # compute average loss and accuracy
        train_loss /= num_train_iters
        val_loss /= num_val_iters
        train_metrics *= 100 / num_train_iters
        val_metrics *= 100 / num_val_iters

        # logging metrics to wandb
        wandb.log({
            'epoch': epoch,
            'train_loss': train_loss,
            'train_fscore_gender': train_metrics[0],
            'train_fscore_accent': train_metrics[1],
            'val_loss': val_loss,
            'val_fscore_gender': val_metrics[0],
            'val_fscore_accent': val_metrics[1]
        })

        # print training & validation statistics
        print(
            "Epoch: {}\tTraining Loss: {:.6f}\tTraining F-score: {}\tValidation Loss: {:.6f}\tValidation F-score: {}"
            .format(epoch, train_loss, train_metrics, val_loss, val_metrics))

        # saving the model when validation loss decreases
        if val_loss <= val_loss_min:
            print(
                "Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ..."
                .format(val_loss_min, val_loss))
            torch.save(
                (model.module
                 if isinstance(model, DataParallel) else model).state_dict(),
                save_path)
            val_loss_min = val_loss
Code Example #27
0
def main():
    """Training routing for the Beyonder Network.
    """
    args = parse_args()
    torch.manual_seed(0)
    exp_name = f'beymax-{args.blocks}-{args.nhead}-{args.nhid}-{args.noutput}reg'
    wandb.init(project='DeepMetaLearning', name=exp_name, config=args)

    base_data_train = DataLoader(BaseDataDataset("data/train/", args.nrows,
                                                 args.ncols),
                                 batch_size=args.batch_size,
                                 shuffle=True, num_workers=8)
    base_data_valid = DataLoader(BaseDataDataset("data/valid/", args.nrows,
                                                 args.ncols),
                                 batch_size=args.batch_size,
                                 num_workers=8)

    total_steps = len(base_data_train)*args.epochs

    model = AttentionMetaExtractor(args.nrows, args.noutput, args.nhead,
                                   args.nhid, args.blocks, dropout=args.dropout)
    model.to(args.device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate,
                                  amsgrad=True)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3,
                                                    total_steps=total_steps)

    best_loss = float("inf")
    progress_bar = tqdm(range(total_steps))
    for epoch in range(args.epochs):
        model.train()
        train_loss = []
        for batch in base_data_train:
            x, y = [tensor.to(args.device) for tensor in batch]
            output = model(x)
            loss = F.cross_entropy(output, y)
            train_loss.append(loss.item())
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
        mloss = np.mean(train_loss)
        wandb.log({"train/loss": mloss, "epoch": epoch})

        model.eval()
        valid_loss = []
        for batch in base_data_valid:
            x, y = [tensor.to(args.device) for tensor in batch]
            output = model(x)
            loss = F.cross_entropy(output, y)
            valid_loss.append(loss.item())
        mloss = np.mean(valid_loss)
        wandb.log({"valid/loss": mloss, "epoch": epoch})
        if mloss < best_loss:
            best_loss = mloss
            output_dir = pathlib.Path(f"model")
            output_dir.mkdir(exist_ok=True)
            best_name = f"best-{exp_name}-{epoch}-{mloss:.5f}.pth"
            torch.save(model.state_dict(), output_dir/best_name)
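    # Reload the best checkpoint and report micro-averaged recall/precision
    # on the validation set.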
    model.load_state_dict(torch.load(output_dir/best_name))
    model.eval()
    ytrue = []
    yhat = []
    for batch in base_data_valid:
        x, y = [tensor.to(args.device) for tensor in batch]
        ytrue += y.tolist()
        output = model(x)
        yhat += output.argmax(dim=1).tolist()
    recall = metrics.recall_score(ytrue, yhat, average="micro")
    precis = metrics.precision_score(ytrue, yhat, average="micro")
    wandb.log({"recall": recall})
    wandb.log({"precision": precis})
Code Example #28
0
    def forward(self,x=None, v0=None, q0=None, rhoNp1=None, \
                vf=None, a_var=None, rhocr=None, g_var=None, future_r=None, future_s=None,\
                epsq=None, epsv=None,\
                t_var=None, tau=None, nu=None, delta=None, kappa=None,\
                cap_delta=None, lambda_var=None):
        self.print_count += 1
        # offramp_prop=None,
        if v0 is not None: self.v0 = v0.view(-1, 1)
        if q0 is not None: self.q0 = q0.view(-1, 1)
        if rhoNp1 is not None: self.rhoNp1 = rhoNp1.view(-1, 1)
        if vf is not None: self.vf = torch.mean(vf)
        if a_var is not None: self.a_var = torch.mean(a_var).view(-1, 1)
        if rhocr is not None: self.rhocr = torch.mean(rhocr).view(-1, 1)
        if g_var is not None: self.g_var = torch.mean(g_var).view(-1, 1)
        if future_r is not None: self.future_r = future_r
        if future_s is not None: self.future_s = future_s
        # if offramp_prop is not None: self.offramp_prop = offramp_prop
        if epsq is not None: self.epsq = epsq.view(-1, 1)
        if epsv is not None: self.epsv = epsv.view(-1, 1)
        if t_var is not None: self.t_var = t_var.view(-1, 1)
        if tau is not None: self.tau = tau.view(-1, 1)
        if nu is not None: self.nu = nu.view(-1, 1)
        if delta is not None: self.delta = delta.view(-1, 1)
        if kappa is not None: self.kappa = kappa.view(-1, 1)
        if cap_delta is not None: self.cap_delta = cap_delta
        if lambda_var is not None: self.lambda_var = lambda_var

        x = x.view(-1, self.num_segments, self.inputs_per_segment)

        self.current_densities = x[:, :, self.rho_index]  #/ self.lambda_var #* (self.g_var+1e-6)#/ (((100.*self.g_var/1000.))))#*self.lambda_var+self.TINY))
        # self.current_flows = x[:, :, self.q_index] #/ self.lambda_var #+ self.epsq #########
        # density = veh/km
        # flow = veh/h
        # vel = km/h
        self.current_onramp = self.active_onramps.float() * x[:, :,
                                                              self.r_index]
        self.current_offramp = self.active_offramps.float() * x[:, :,
                                                                self.s_index]

        self.current_velocities = x[:, :, self.v_index]
        # self.current_velocities = self.current_flows / (self.current_densities*self.lambda_var+self.TINY)
        self.current_velocities = torch.clamp(self.current_velocities,
                                              min=self.vmin,
                                              max=self.vmax)
        self.current_densities = torch.clamp(self.current_densities,
                                             min=0.,
                                             max=1000.)

        self.current_flows = self.current_velocities * (
            self.current_densities * self.lambda_var)

        self.current_flows = torch.clamp(self.current_flows,
                                         min=0.,
                                         max=10000.)
        self.current_onramp = torch.clamp(self.current_onramp,
                                          min=0.,
                                          max=5000.)
        self.current_offramp = torch.clamp(self.current_offramp,
                                           min=0.,
                                           max=5000.)
        self.v0 = torch.clamp(self.v0, min=self.vmin, max=self.vmax)

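        # Build shifted views of the segment state: prepend the upstream
        # boundary conditions (v0, q0) and append the downstream density
        # boundary (rhoNp1) so each segment can see its neighbours.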
        self.prev_velocities = torch.cat(
            [self.v0, self.current_velocities[:, :-1]], dim=1)
        self.next_densities = torch.cat(
            [self.current_densities[:, 1:], self.rhoNp1], dim=1)
        self.prev_flows = torch.cat([self.q0, self.current_flows[:, :-1]],
                                    dim=1)

        future_velocities = self.future_v()
        future_velocities = torch.clamp(future_velocities,
                                        min=self.vmin,
                                        max=self.vmax)
        future_densities = self.future_rho()
        #future_occupancies = (future_densities) / (self.g_var+1e-6)#* (100*self.g_var/1000) #* self.lambda_var
        # future_occupancies = (future_densities / self.lambda_var) / (self.g_var+1e-6)

        future_flows = future_densities * future_velocities * self.lambda_var  #- self.epsq

        #old future_s = self.active_offramps * (self.offramp_prop*self.current_flows) #active_offramps.float() *
        # future_s = self.active_offramps * (self.offramp_prop*self.prev_flows) #active_offramps.float() *
        future_s = self.active_offramps.float() * future_s

        future_r = self.active_onramps.float() * future_r

        try:
            if self.print_count % self.print_every == 0:
                wandb.log({
                    "future_velocities":
                    wandb.Histogram(future_velocities.cpu().detach().numpy())
                })
                wandb.log({
                    "future_densities":
                    wandb.Histogram(future_densities.cpu().detach().numpy())
                })
                #wandb.log({"future_occupancies": wandb.Histogram(future_occupancies.cpu().detach().numpy())})
                wandb.log({
                    "future_flows":
                    wandb.Histogram(future_flows.cpu().detach().numpy())
                })
                wandb.log({
                    "future_r":
                    wandb.Histogram(future_r.cpu().detach().numpy())
                })
                wandb.log({
                    "future_s":
                    wandb.Histogram(future_s.cpu().detach().numpy())
                })
                wandb.log({
                    'mean_future_velocities':
                    future_velocities.cpu().detach().numpy().mean(),
                    'mean_future_densities':
                    future_densities.cpu().detach().numpy().mean(),
                    'mean_future_flows':
                    future_flows.cpu().detach().numpy().mean()
                })
                wandb.log({
                    "future_r_4":
                    wandb.Histogram(future_r[:, 3].cpu().detach().numpy())
                })
                wandb.log({
                    "future_s_2":
                    wandb.Histogram(future_s[:, 1].cpu().detach().numpy())
                })
                wandb.log({
                    "future_flows_1":
                    wandb.Histogram(future_flows[:, 0].cpu().detach().numpy())
                })
                wandb.log({
                    "future_flows_2":
                    wandb.Histogram(future_flows[:, 1].cpu().detach().numpy())
                })
                wandb.log({
                    "future_flows_3":
                    wandb.Histogram(future_flows[:, 2].cpu().detach().numpy())
                })
                wandb.log({
                    "future_flows_4":
                    wandb.Histogram(future_flows[:, 3].cpu().detach().numpy())
                })
                wandb.log({
                    "future_flows_1_to_2":
                    wandb.Histogram(future_flows[:, 1].cpu().detach().numpy() -
                                    future_flows[:, 0].cpu().detach().numpy())
                })
                wandb.log({
                    "future_flows_2_to_3":
                    wandb.Histogram(future_flows[:, 2].cpu().detach().numpy() -
                                    future_flows[:, 1].cpu().detach().numpy())
                })
                wandb.log({
                    "future_flows_3_to_4":
                    wandb.Histogram(future_flows[:, 3].cpu().detach().numpy() -
                                    future_flows[:, 2].cpu().detach().numpy())
                })

                wandb.log({
                    "flow_residual":
                    wandb.Histogram(self.flow_residual.cpu().detach().numpy())
                })

                wandb.log({
                    "epsq":
                    wandb.Histogram(self.epsq.cpu().detach().numpy())
                })
        except Exception as e:
            print(e)

        # future_densities = future_densities * self.lambda_var
        future_velocities = torch.clamp(future_velocities, min=0, max=120)
        future_densities = torch.clamp(future_densities, min=0, max=1000)
        # future_occupancies = torch.clamp(future_occupancies, min=0, max=100)
        future_flows = torch.clamp(future_flows, min=0, max=10000)
        future_r = torch.clamp(future_r, min=0, max=10000)
        future_s = torch.clamp(future_s, min=0, max=10000)

        # one_stack =  torch.stack((future_flows,future_occupancies,future_r,future_s),dim=2)

        # one_stack =  torch.stack((future_flows,future_densities,future_velocities,future_r,future_s),dim=2)
        one_stack = torch.stack(
            (future_densities, future_velocities, future_r, future_s), dim=2)

        return one_stack.view(-1, self.num_segments *
                              self.inputs_per_segment), self.flow_residual
Code Example #29
0
def train_256(epoch, state_dict, model, optimizer, train_loader, valid_loader,
              args, logger):
    model.train()

    # Train loop
    for data in tqdm(train_loader):
        start_time = time.time()
        optimizer.zero_grad()

        #Downsample then reconstruct upsampled version
        x = data[0]
        y = F.interpolate(F.interpolate(x,
                                        args.low_resolution,
                                        mode="bilinear"),
                          args.image_size,
                          mode="bilinear")
        x = x.to(args.device)
        y = y.to(args.device)
        x_mask = x - y
        x_mask_hat = model(y)
        x_hat = y + x_mask_hat

        #Compute loss and take step
        loss = loss_func(x_mask_hat, x_mask)
        loss.backward()
        optimizer.step()

        # Calculate iteration time
        end_time = time.time()
        itr_time = end_time - start_time

        # Update logger & wandb
        logger.update(state_dict['itr'], loss.cpu().item(), itr_time)
        wandb.log({'train_loss': loss.item()}, commit=False)
        wandb.log({'train_itr_time': itr_time}, commit=True)

        # Save images, logger, weights on save_every interval
        if not state_dict['itr'] % args.save_every:
            # Save images
            save_image(
                x_mask.cpu(), args.output_dir +
                'train_real_mask_itr{}.png'.format(state_dict['itr']))
            save_image(
                x_mask_hat.cpu(), args.output_dir +
                'train_recon_mask_itr{}.png'.format(state_dict['itr']))
            save_image(
                y.cpu(), args.output_dir +
                'train_low_img_256_itr{}.png'.format(state_dict['itr']))
            save_image(
                x_hat.cpu(), args.output_dir +
                'train_recon_img_256_itr{}.png'.format(state_dict['itr']))
            save_image(
                x.cpu(), args.output_dir +
                'train_real_img_256_itr{}.png'.format(state_dict['itr']))

            # Save model & optimizer weights
            torch.save(
                {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, args.output_dir +
                '/UNET_pixel_model_256_itr{}.pth'.format(state_dict['itr']))

            # Save logger
            torch.save(logger, args.output_dir + '/logger.pth')

        if not state_dict['itr'] % args.valid_every and state_dict['itr'] != 0:
            print("here")
            model.eval()
            val_losses = []

            with torch.no_grad():
                for data in tqdm(valid_loader):
                    print("here")
                    x_val = data[0]
                    y_val = F.interpolate(F.interpolate(x_val,
                                                        args.low_resolution,
                                                        mode="bilinear"),
                                          args.image_size,
                                          mode="bilinear")
                    x_val = x_val.to(args.device)
                    y_val = y_val.to(args.device)
                    x_mask_val = x_val - y_val
                    x_mask_hat_val = model(y_val)
                    x_hat_val = y_val + x_mask_hat_val
                    loss_val = loss_func(x_mask_hat_val, x_mask_val)
                    val_losses.append(loss_val.item())

                save_image(
                    x_mask_val.cpu(), args.output_dir +
                    'val_real_mask_itr{}.png'.format(state_dict['itr']))
                save_image(
                    x_mask_hat_val.cpu(), args.output_dir +
                    'val_recon_mask_itr{}.png'.format(state_dict['itr']))
                save_image(
                    y_val.cpu(), args.output_dir +
                    'val_low_img_256_itr{}.png'.format(state_dict['itr']))
                save_image(
                    x_hat_val.cpu(), args.output_dir +
                    'val_recon_img_256_itr{}.png'.format(state_dict['itr']))
                save_image(
                    x_val.cpu(), args.output_dir +
                    'val_real_img_256_itr{}.png'.format(state_dict['itr']))

                val_losses_mean = np.mean(val_losses)
                wandb.log({'val_loss': val_losses_mean}, commit=True)
                logger.update_val_loss(state_dict['itr'], val_losses_mean)
                val_losses.clear()

            model.train()
        # Increment iteration number
        state_dict['itr'] += 1
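train_256 above logs two metrics per iteration with commit=False followed by commit=True, which is how wandb groups several wandb.log calls into a single logged step. A hedged sketch of just that behavior (the project name and values are illustrative):

import wandb

wandb.init(project="commit-demo")  # hypothetical project name
loss_value, itr_time = 0.42, 0.031  # illustrative numbers
wandb.log({'train_loss': loss_value}, commit=False)   # buffered; the step is not advanced yet
wandb.log({'train_itr_time': itr_time}, commit=True)  # flushes both metrics as one step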
Code example #30
File: reporting_wandb.py  Project: dpressel/baseline
def step(self, metrics, tick, phase, tick_type=None, **kwargs):
    metrics = {'{}/{}'.format(phase, key): metrics[key] for key in metrics}
    wandb.log(metrics)
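This reporting hook namespaces every metric under its phase (for example 'Train/avg_loss'), so the wandb dashboard groups train and validation curves into separate sections. A small usage sketch, with hypothetical metric names and values:

import wandb

wandb.init(project="reporting-demo")  # hypothetical project name
metrics = {'avg_loss': 1.23, 'perplexity': 3.4}  # illustrative metrics
phase = 'Train'
wandb.log({'{}/{}'.format(phase, key): value for key, value in metrics.items()})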
Code example #31
def main(config):
    torch.manual_seed(config.seed)
    random.seed(config.seed)
    np.random.seed(config.seed)

    cuda = not config.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")

    # define data paths
    decoder_path = config.models_dir / 'decoder' / config.env / config.name
    decoder_path.mkdir(parents=True, exist_ok=True)

    encoder_path = config.encoder_path
    if encoder_path is None:
        encoder_path = config.models_dir / 'encoder' / config.env / config.name / 'encoder.pt'

    log_path = config.logdir / 'decoder' / config.env / config.name
    log_path.mkdir(parents=True, exist_ok=True)
    util.write_options(config, log_path)

    # builds a dataset by stepping a gym env with random actions
    dataset = gym_dataset.load_or_generate(config)
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=config.batch_size,
                                               shuffle=True,
                                               num_workers=0)

    # load encoder
    encoder = torch.load(encoder_path)

    # calculate the sizes of everything
    epoch_size = len(dataset)
    action_space = dataset.env.action_space
    action_size = util.prod(action_space.shape)
    if config.embed_size is None:
        embed_size = action_size
    else:
        embed_size = config.embed_size

    # define decoder
    decoder = ActionDecoder(
        config.n_layers,
        embed_size,
        config.traj_len,
        action_space).to(device)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=config.lr)

    # calculate necessary statistics
    z_stats = marginal_stats(train_loader, encoder, device)

    if config.wandb:
        wandb.init(project=config.wandb_proj, entity='vinnibuh')
        wandb.config.update(util.clean_config(config))

    for epoch in range(config.epochs):
        decoder_loss = 0
        decoder_recon_loss = 0
        decoder_norm_loss = 0

        temp_batch = 0
        temp_loss = 0
        temp_recon_loss = 0
        temp_norm_loss = 0

        for batch_idx, (states, actions) in enumerate(train_loader):
            z = turn_to_z(actions, encoder, device)
            z = (z - z_stats[0].detach()) / z_stats[1].detach()
            decoded_action = decoder(z)
            z_hat = encoder.encode(decoded_action)[0]
            z_hat_white = (z_hat - z_stats[0].detach()) / z_stats[1].detach()

            recon_loss = F.mse_loss(z_hat_white, z)
            norm_loss = decoded_action.norm(dim=2).sum()
            loss = recon_loss + config.norm_loss * norm_loss

            decoder_optimizer.zero_grad()
            loss.backward()
            decoder_optimizer.step()

            decoder_loss += loss.item()
            decoder_recon_loss += recon_loss.item()
            decoder_norm_loss += norm_loss.item()

            temp_loss += loss.item()
            temp_recon_loss += recon_loss.item()
            temp_norm_loss += norm_loss.item()

            if batch_idx > 0 and (batch_idx * config.batch_size) % config.log_interval < config.batch_size:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * config.batch_size, epoch_size,
                    100. * batch_idx / len(train_loader),
                    loss.item() / config.batch_size))
                temp_size = (batch_idx - temp_batch) * config.batch_size
                temp_batch = batch_idx
                wandb.log({'epoch_progress': 100. * batch_idx / len(train_loader)})
                wandb.log({'mean batch loss': temp_loss / temp_size})
                wandb.log({'mean batch recon loss': temp_recon_loss / temp_size})
                wandb.log({'mean batch norm loss': temp_norm_loss / temp_size})
                temp_loss = 0
                temp_recon_loss = 0
                temp_norm_loss = 0

        print((
            'ActionDecoder epoch: {}\tAverage loss: {:.4f}'
            '\tRecon loss: {:.6f}\tNorm loss: {:.6f}'
        ).format(
            epoch, decoder_loss / epoch_size,
            decoder_recon_loss / epoch_size,
            decoder_norm_loss / epoch_size))
        wandb.log({'epoch': epoch})
        wandb.log({'mean epoch loss': decoder_loss / epoch_size})
        wandb.log({'mean epoch recon loss': decoder_recon_loss / epoch_size})
        wandb.log({'mean epoch norm loss': decoder_norm_loss / epoch_size})

    # z_stats[2] is the max
    decoder.mean_z = z_stats[0]
    decoder.std_z = z_stats[1]
    decoder.max_embedding = z_stats[2]
    torch.save(decoder, decoder_path / 'decoder.pt')
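The run above registers its hyperparameters with wandb.config.update before training and then logs per-epoch aggregates. A minimal sketch of that init-plus-config pattern (the project name and values are illustrative, not the original configuration):

import wandb

config_dict = {'lr': 1e-3, 'batch_size': 64, 'epochs': 10}  # stand-in for util.clean_config(config)
wandb.init(project="action-decoder-demo")  # hypothetical project; the original also sets an entity
wandb.config.update(config_dict)                 # store hyperparameters with the run
wandb.log({'epoch': 0, 'mean epoch loss': 0.0})  # several metrics can share one log call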