Conversation
chaserileyroberts
left a comment
There was a problem hiding this comment.
Looks good! Thanks for the fix! Approval after small comment.
|
|
||
| def concat(self, values: Sequence[Tensor], axis) -> Tensor: | ||
| def shape_concat(self, values: Sequence[Tensor], axis) -> Tensor: | ||
| """Concatenate a sequence of tensors together about the given axis.""" |
There was a problem hiding this comment.
Can you change the description here to be explicitly only for shape calculations?
There was a problem hiding this comment.
Added to the description, now reads:
"""Concatenate a sequence of tensors together about the given axis,
intended only for use in shape calculations"""
|
Hmm. Still seems like there's a shape based error for the Pytorch backend. Can you investigate? |
| return np.concatenate(values, axis) | ||
|
|
||
| def concat(self, values: Tensor, axis: int = 0) -> Tensor: | ||
| return np.stack(values, axis) |
There was a problem hiding this comment.
Aw, that would explain it. You'll need to use pytorch's concat equivalent here instead of numpy.
| def concat(self, values: Sequence[Tensor], axis: int = 0) -> Tensor: | ||
| new_shape = None | ||
| if axis == 0: | ||
| new_shape = ShellTensor(values) | ||
| else: | ||
| new_shape = self.shape_concat(values, axis) | ||
| return new_shape | ||
|
|
There was a problem hiding this comment.
I don't this this is correct. ShellTensor is a tensor type that only stores its shape and has no concrete values, so when you do a normal concatenation of multiple ShellTensors, you'll want to add the value of that specified axis together.
Though this function really isn't important at all. Just throw a NotImplementedError and we'll add it when we actually need it.
| return result | ||
|
|
||
| def concat(self, values: Tensor, axis: int) -> Tensor: | ||
| def shape_concat(self, values: Tensor, axis: int) -> Tensor: |
There was a problem hiding this comment.
Doubly come to think of it, we don't need the axis number for shape_concat at all since we always use just -1. Can you remove this argument?
|
It looks as though the Travis build for python 3.7 failed due to an error when downloading pytype. The Travis build for python 3.6 successfully completed. |
Addresses issue #350, renames existing
backend.concatmethods tobackend.shape_concatand implementbackend.concatmethods using backendstackmethod. If namebackend.shape_concatis too similar to existingbackend.concat_shapemethods could alternatively foldbackend.concatinto existing method of same name as special case whenaxis == 0. Did not squash commits.